diff --git a/orangecontrib/text/widgets/owscoredocuments.py b/orangecontrib/text/widgets/owscoredocuments.py index d50b7b5c0..c46650654 100644 --- a/orangecontrib/text/widgets/owscoredocuments.py +++ b/orangecontrib/text/widgets/owscoredocuments.py @@ -1,20 +1,38 @@ import re from collections import Counter +from contextlib import contextmanager from inspect import signature from typing import Callable, List, Tuple, Union import numpy as np -from AnyQt.QtCore import QSortFilterProxyModel, Qt -from AnyQt.QtWidgets import QHeaderView, QLineEdit, QTableView +from AnyQt.QtCore import ( + QItemSelection, + QItemSelectionModel, + QSortFilterProxyModel, + Qt, + Signal, +) +from AnyQt.QtWidgets import ( + QButtonGroup, + QGridLayout, + QHeaderView, + QLineEdit, + QRadioButton, + QTableView, +) +from pandas import isnull +from sklearn.metrics.pairwise import cosine_similarity + +# todo: uncomment when minimum version of Orange is 3.29.2 +# from orangecanvas.gui.utils import disconnected +from orangewidget import gui from Orange.data import ContinuousVariable, Domain, StringVariable, Table from Orange.util import wrap_callback +from Orange.widgets.settings import ContextSetting, PerfectDomainContextHandler, Setting +from Orange.widgets.utils.annotated_data import create_annotated_table from Orange.widgets.utils.concurrent import ConcurrentWidgetMixin, TaskState from Orange.widgets.utils.itemmodels import PyTableModel, TableModel from Orange.widgets.widget import Input, Msg, Output, OWWidget -from orangewidget import gui -from orangewidget.settings import Setting -from pandas import isnull -from sklearn.metrics.pairwise import cosine_similarity from orangecontrib.text import Corpus from orangecontrib.text.preprocess import BaseNormalizer, BaseTransformer @@ -24,6 +42,16 @@ ) +# todo: remove when minimum version of Orange is 3.29.2 +@contextmanager +def disconnected(signal, slot, type=Qt.UniqueConnection): + signal.disconnect(slot) + try: + yield + finally: + signal.connect(slot, type) + + def _word_frequency(corpus: Corpus, words: List[str], callback: Callable) -> np.ndarray: res = [] tokens = corpus.tokens @@ -183,12 +211,20 @@ def callback(i: float) -> None: state.set_partial_result((sm, aggregation, scs)) +class SelectionMethods: + NONE, ALL, MANUAL, N_BEST = range(4) + ITEMS = "None", "All", "Manual", "Top documents" + + class ScoreDocumentsTableView(QTableView): + pressedAny = Signal() + def __init__(self): super().__init__( sortingEnabled=True, editTriggers=QTableView.NoEditTriggers, - selectionMode=QTableView.NoSelection, + selectionMode=QTableView.ExtendedSelection, + selectionBehavior=QTableView.SelectRows, cornerButtonEnabled=False, ) self.setItemDelegate(gui.ColoredBarItemDelegate(self)) @@ -215,6 +251,10 @@ def update_column_widths(self) -> None: # document title column is one that stretch header.setSectionResizeMode(0, QHeaderView.Stretch) + def mousePressEvent(self, event): + super().mousePressEvent(event) + self.pressedAny.emit() + class ScoreDocumentsProxyModel(QSortFilterProxyModel): @staticmethod @@ -253,23 +293,31 @@ class OWScoreDocuments(OWWidget, ConcurrentWidgetMixin): icon = "icons/ScoreDocuments.svg" priority = 500 + buttons_area_orientation = Qt.Vertical + # default order - table sorted in input order DEFAULT_SORTING = (-1, Qt.AscendingOrder) + settingsHandler = PerfectDomainContextHandler() auto_commit: bool = Setting(True) aggregation: int = Setting(0) word_frequency: bool = Setting(True) word_appearance: bool = Setting(False) embedding_similarity: bool = Setting(False) + embedding_language: int = Setting(0) + sort_column_order: Tuple[int, int] = Setting(DEFAULT_SORTING) - embedding_language = Setting(0) + selected_rows: List[int] = ContextSetting([], schema_only=True) + sel_method: int = ContextSetting(SelectionMethods.N_BEST) + n_selected: int = ContextSetting(3) class Inputs: corpus = Input("Corpus", Corpus) words = Input("Words", Table) class Outputs: + selected_documents = Output("Selected documents", Corpus, default=True) corpus = Output("Corpus", Corpus) class Warning(OWWidget.Warning): @@ -322,6 +370,33 @@ def _setup_control_area(self) -> None: ) gui.rubber(self.controlArea) + + # select words box + box = gui.vBox(self.buttonsArea, "Select Documents") + grid = QGridLayout() + grid.setContentsMargins(0, 0, 0, 0) + + self._sel_method_buttons = QButtonGroup() + for method, label in enumerate(SelectionMethods.ITEMS): + button = QRadioButton(label) + button.setChecked(method == self.sel_method) + grid.addWidget(button, method, 0) + self._sel_method_buttons.addButton(button, method) + self._sel_method_buttons.buttonClicked[int].connect(self.__set_selection_method) + + spin = gui.spin( + box, + self, + "n_selected", + 1, + 999, + addToLayout=False, + callback=lambda: self.__set_selection_method(SelectionMethods.N_BEST), + ) + grid.addWidget(spin, 3, 1) + box.layout().addLayout(grid) + + # autocommit gui.auto_send(self.buttonsArea, self, "auto_commit") def _setup_main_area(self) -> None: @@ -333,7 +408,11 @@ def _setup_main_area(self) -> None: self.model = model = ScoreDocumentsTableModel(parent=self) model.setHorizontalHeaderLabels(["Document"]) + def select_manual(): + self.__set_selection_method(SelectionMethods.MANUAL) + self.view = view = ScoreDocumentsTableView() + view.pressedAny.connect(select_manual) self.mainArea.layout().addWidget(view) # by default data are sorted in the Table order header = self.view.horizontalHeader() @@ -344,6 +423,7 @@ def _setup_main_area(self) -> None: proxy_model.setFilterCaseSensitivity(False) view.setModel(proxy_model) view.model().setSourceModel(self.model) + self.view.selectionModel().selectionChanged.connect(self.__on_selection_change) def __on_filter_changed(self) -> None: model = self.view.model() @@ -352,15 +432,39 @@ def __on_filter_changed(self) -> None: def __on_horizontal_header_clicked(self, index: int): header = self.view.horizontalHeader() self.sort_column_order = (index, header.sortIndicatorOrder()) + self._select_rows() + # when sorting change output table must consider the new order + # call explicitly since selection in table is not changed + if ( + self.sel_method == SelectionMethods.MANUAL + and self.selected_rows + or self.sel_method == SelectionMethods.ALL + ): + # retrieve selection in new order + self.selected_rows = self.get_selected_indices() + self._send_output() + + def __on_selection_change(self): + self.selected_rows = self.get_selected_indices() + self._send_output() + + def __set_selection_method(self, method: int): + self.sel_method = method + self._sel_method_buttons.button(method).setChecked(True) + self._select_rows() @Inputs.corpus def set_data(self, corpus: Corpus) -> None: + self.closeContext() self.Warning.corpus_not_normalized.clear() if corpus is not None: self.Warning.missing_corpus.clear() if not self._is_corpus_normalized(corpus): self.Warning.corpus_not_normalized() self.corpus = corpus + self.selected_rows = [] + self.openContext(corpus) + self._sel_method_buttons.button(self.sel_method).setChecked(True) self._clear_and_run() @staticmethod @@ -416,10 +520,16 @@ def _gather_scores(self) -> Tuple[np.ndarray, List[str]]: labels = [SCORING_METHODS[m][0] for m in methods] return scores, labels - def _send_output(self, scores: np.ndarray, labels: List[str]) -> None: + def _send_output(self) -> None: """ Create corpus with scores and output it """ + if self.corpus is None: + self.Outputs.corpus.send(None) + self.Outputs.selected_documents.send(None) + return + + scores, labels = self._gather_scores() if labels: d = self.corpus.domain domain = Domain( @@ -427,39 +537,54 @@ def _send_output(self, scores: np.ndarray, labels: List[str]) -> None: d.class_var, metas=d.metas + tuple(ContinuousVariable(l) for l in labels), ) - corpus = Corpus( + out_corpus = Corpus( domain, self.corpus.X, self.corpus.Y, np.hstack([self.corpus.metas, scores]), ) - Corpus.retain_preprocessing(self.corpus, corpus) - self.Outputs.corpus.send(corpus) - elif self.corpus is not None: - self.Outputs.corpus.send(self.corpus) + Corpus.retain_preprocessing(self.corpus, out_corpus) else: - self.Outputs.corpus.send(None) + out_corpus = self.corpus - def _fill_table(self, scores: np.ndarray, labels: List[str]) -> None: + self.Outputs.corpus.send(create_annotated_table(out_corpus, self.selected_rows)) + self.Outputs.selected_documents.send( + out_corpus[self.selected_rows] if self.selected_rows else None + ) + + def _fill_table(self) -> None: """ Fill the table in the widget with scores and document names """ if self.corpus is None: self.model.clear() return + scores, labels = self._gather_scores() labels = ["Document"] + labels titles = self.corpus.titles.tolist() + + # clearing selection and sorting to prevent SEGFAULT on model.wrap + self.view.horizontalHeader().setSortIndicator(-1, Qt.AscendingOrder) + with disconnected( + self.view.selectionModel().selectionChanged, self.__on_selection_change + ): + self.view.clearSelection() + self.model.wrap([[c] + s for c, s in zip(titles, scores.tolist())]) self.model.setHorizontalHeaderLabels(labels) self.view.update_column_widths() + if self.model.columnCount() > self.sort_column_order[0]: + # if not enough columns do not apply sorting from settings since + # sorting can besaved for score column while scores are still computing + # tables is filled before scores are computed with document names + self.view.horizontalHeader().setSortIndicator(*self.sort_column_order) - self.view.horizontalHeader().setSortIndicator(*self.sort_column_order) + self._select_rows() def _fill_and_output(self) -> None: """Fill the table in the widget and send the output""" - scores, labels = self._gather_scores() - self._fill_table(scores, labels) - self._send_output(scores, labels) + self._fill_table() + self._send_output() def _clear_and_run(self) -> None: """Clear cached scores and commit""" @@ -500,14 +625,12 @@ def commit(self) -> None: self._fill_and_output() def on_done(self, _: None) -> None: - scores, labels = self._gather_scores() - self._send_output(scores, labels) + self._send_output() def on_partial_result(self, result: Tuple[str, str, np.ndarray]) -> None: sc_method, aggregation, scores = result self.scores[(sc_method, aggregation)] = scores - scores, labels = self._gather_scores() - self._fill_table(scores, labels) + self._fill_table() def on_exception(self, ex: Exception) -> None: self.Error.custom_err(ex) @@ -543,6 +666,49 @@ def _is_corpus_normalized(corpus: Corpus) -> bool: for pp in corpus.used_preprocessor.preprocessors ) + def get_selected_indices(self) -> List[int]: + # get indices in table's order - that the selected output table have same order + selected_rows = sorted( + self.view.selectionModel().selectedRows(), key=lambda idx: idx.row() + ) + return [self.view.model().mapToSource(r).row() for r in selected_rows] + + def _select_rows(self): + proxy_model = self.view.model() + n_rows, n_columns = proxy_model.rowCount(), proxy_model.columnCount() + if self.sel_method == SelectionMethods.NONE: + selection = QItemSelection() + elif self.sel_method == SelectionMethods.ALL: + selection = QItemSelection( + proxy_model.index(0, 0), proxy_model.index(n_rows - 1, n_columns - 1) + ) + elif self.sel_method == SelectionMethods.MANUAL: + selection = QItemSelection() + new_sel = [] + for row in self.selected_rows: + if row < n_rows: + new_sel.append(row) + _selection = QItemSelection( + self.model.index(row, 0), self.model.index(row, n_columns - 1) + ) + selection.merge( + proxy_model.mapSelectionFromSource(_selection), + QItemSelectionModel.Select, + ) + # selected rows must be updated when the same dataset with less rows + # appear at the input - it is not handled by selectionChanged + # in cases when all selected rows missing in new table + self.selected_rows = new_sel + elif self.sel_method == SelectionMethods.N_BEST: + n_sel = min(self.n_selected, n_rows) + selection = QItemSelection( + proxy_model.index(0, 0), proxy_model.index(n_sel - 1, n_columns - 1) + ) + else: + raise NotImplementedError + + self.view.selectionModel().select(selection, QItemSelectionModel.ClearAndSelect) + if __name__ == "__main__": from orangewidget.utils.widgetpreview import WidgetPreview diff --git a/orangecontrib/text/widgets/tests/test_owscoredocuments.py b/orangecontrib/text/widgets/tests/test_owscoredocuments.py index 9d8d4fc20..bbd958b63 100644 --- a/orangecontrib/text/widgets/tests/test_owscoredocuments.py +++ b/orangecontrib/text/widgets/tests/test_owscoredocuments.py @@ -4,7 +4,8 @@ from unittest.mock import patch import numpy as np -from AnyQt.QtCore import Qt +from AnyQt.QtCore import QItemSelectionModel, Qt + from Orange.data import ContinuousVariable, Domain, StringVariable, Table from Orange.misc.collections import natural_sorted from Orange.util import dummy_callback @@ -15,6 +16,7 @@ from orangecontrib.text.vectorization.document_embedder import DocumentEmbedder from orangecontrib.text.widgets.owscoredocuments import ( OWScoreDocuments, + SelectionMethods, _preprocess_words, ) @@ -74,8 +76,9 @@ def test_set_data(self): output = self.get_output(self.widget.Outputs.corpus) self.assertTupleEqual(output.domain.variables, self.corpus.domain.variables) - self.assertTupleEqual(output.domain.metas[:-1], self.corpus.domain.metas) - self.assertEqual(str(output.domain.metas[-1]), "Word count") + self.assertTupleEqual(output.domain.metas[:1], self.corpus.domain.metas) + self.assertEqual(str(output.domain.metas[1]), "Word count") + self.assertEqual(str(output.domain.metas[2]), "Selected") self.assertEqual(len(output), len(self.corpus)) def test_corpus_not_normalized(self): @@ -352,11 +355,14 @@ def test_sort_setting(self): """ view = self.widget.view model = view.model() - self.widget.sort_column_order = (1, Qt.DescendingOrder) self.send_signal(self.widget.Inputs.corpus, self.corpus) self.send_signal(self.widget.Inputs.words, self.words) self.wait_until_finished() + self.widget.sort_column_order = (1, Qt.DescendingOrder) + self.widget._fill_table() + self.wait_until_finished() + header = self.widget.view.horizontalHeader() current_sorting = (header.sortIndicatorSection(), header.sortIndicatorOrder()) data = [model.data(model.index(i, 1)) for i in range(model.rowCount())] @@ -385,6 +391,49 @@ def test_sort_setting(self): self.assertTupleEqual((1, Qt.DescendingOrder), current_sorting) self.assertListEqual(sorted(data, reverse=True), data) + def test_selection_none(self): + self.send_signal(self.widget.Inputs.corpus, self.corpus) + self.widget._sel_method_buttons.button(SelectionMethods.NONE).click() + + output = self.get_output(self.widget.Outputs.selected_documents) + self.assertIsNone(output) + + def tests_selection_all(self): + self.send_signal(self.widget.Inputs.corpus, self.corpus) + self.widget._sel_method_buttons.button(SelectionMethods.ALL).click() + + output = self.get_output(self.widget.Outputs.selected_documents) + self.assertEqual(len(self.corpus), len(output)) + + def test_selection_manual(self): + self.send_signal(self.widget.Inputs.corpus, self.corpus) + self.widget._sel_method_buttons.button(SelectionMethods.MANUAL).click() + + mode = QItemSelectionModel.Rows | QItemSelectionModel.Select + view = self.widget.view + view.clearSelection() + model = view.model() + view.selectionModel().select(model.index(2, 0), mode) + view.selectionModel().select(model.index(3, 0), mode) + + output = self.get_output(self.widget.Outputs.selected_documents) + self.assertListEqual(["Document 3", "Document 4"], output.titles.tolist()) + + def test_selection_n_best(self): + self.send_signal(self.widget.Inputs.corpus, self.corpus) + self.widget._sel_method_buttons.button(SelectionMethods.N_BEST).click() + + output = self.get_output(self.widget.Outputs.selected_documents) + self.assertListEqual( + [f"Document {i}" for i in range(1, 4)], output.titles.tolist() + ) + + self.widget.controls.n_selected.setValue(5) + output = self.get_output(self.widget.Outputs.selected_documents) + self.assertListEqual( + [f"Document {i}" for i in range(1, 6)], output.titles.tolist() + ) + if __name__ == "__main__": unittest.main()