diff --git a/orangecontrib/text/widgets/owscoredocuments.py b/orangecontrib/text/widgets/owscoredocuments.py index b657d22db..76f914056 100644 --- a/orangecontrib/text/widgets/owscoredocuments.py +++ b/orangecontrib/text/widgets/owscoredocuments.py @@ -1,25 +1,33 @@ import re from collections import Counter from inspect import signature -from typing import List, Callable, Tuple, Union +from typing import Callable, List, Tuple, Union import numpy as np -from pandas import isnull -from Orange.data import ( - Table, - Domain, - StringVariable, - ContinuousVariable, - DiscreteVariable, +from AnyQt.QtCore import ( + QItemSelection, + QItemSelectionModel, + QSortFilterProxyModel, + Qt, + Signal, +) +from AnyQt.QtWidgets import ( + QButtonGroup, + QGridLayout, + QHeaderView, + QLineEdit, + QRadioButton, + QTableView, ) +from Orange.data import ContinuousVariable, Domain, StringVariable, Table from Orange.util import wrap_callback +from Orange.widgets.settings import ContextSetting, DomainContextHandler, Setting +from Orange.widgets.utils.annotated_data import create_annotated_table from Orange.widgets.utils.concurrent import ConcurrentWidgetMixin, TaskState +from Orange.widgets.utils.itemmodels import PyTableModel, TableModel from Orange.widgets.widget import Input, Msg, Output, OWWidget from orangewidget import gui -from orangewidget.settings import Setting -from Orange.widgets.utils.itemmodels import PyTableModel, TableModel -from AnyQt.QtWidgets import QTableView, QLineEdit, QHeaderView -from AnyQt.QtCore import Qt, QSortFilterProxyModel +from pandas import isnull from sklearn.metrics.pairwise import cosine_similarity from orangecontrib.text import Corpus @@ -30,9 +38,7 @@ ) -def _word_frequency( - corpus: Corpus, words: List[str], callback: Callable -) -> np.ndarray: +def _word_frequency(corpus: Corpus, words: List[str], callback: Callable) -> np.ndarray: res = [] tokens = corpus.tokens for i, t in enumerate(tokens): @@ -163,33 +169,35 @@ def callback(i: float) -> None: cb_part = 1 / (len(scoring_methods) + 1) # +1 for preprocessing - words = _preprocess_words( - corpus, words, wrap_callback(callback, end=cb_part) - ) + words = _preprocess_words(corpus, words, wrap_callback(callback, end=cb_part)) for i, sm in enumerate(scoring_methods): scoring_method = SCORING_METHODS[sm][1] sig = signature(scoring_method) - add_params = { - k: v for k, v in additional_params.items() if k in sig.parameters - } + add_params = {k: v for k, v in additional_params.items() if k in sig.parameters} scs = scoring_method( corpus, words, - wrap_callback( - callback, start=(i + 1) * cb_part, end=(i + 2) * cb_part - ), + wrap_callback(callback, start=(i + 1) * cb_part, end=(i + 2) * cb_part), **add_params ) scs = AGGREGATIONS[aggregation](scs, axis=1) state.set_partial_result((sm, aggregation, scs)) +class SelectionMethods: + NONE, ALL, MANUAL, N_BEST = range(4) + ITEMS = "None", "All", "Manual", "Top words" + + class ScoreDocumentsTableView(QTableView): + pressedAny = Signal() + def __init__(self): super().__init__( sortingEnabled=True, editTriggers=QTableView.NoEditTriggers, - selectionMode=QTableView.NoSelection, + selectionMode=QTableView.ExtendedSelection, + selectionBehavior=QTableView.SelectRows, cornerButtonEnabled=False, ) self.setItemDelegate(gui.ColoredBarItemDelegate(self)) @@ -202,7 +210,8 @@ def update_column_widths(self) -> None: """ header = self.horizontalHeader() col_width = max( - [0] + [ + [0] + + [ max(self.sizeHintForColumn(i), header.sectionSizeHint(i)) for i in range(1, self.model().columnCount()) ] @@ -215,6 +224,10 @@ def update_column_widths(self) -> None: # document title column is one that stretch header.setSectionResizeMode(0, QHeaderView.Stretch) + def mousePressEvent(self, event): + super().mousePressEvent(event) + self.pressedAny.emit() + class ScoreDocumentsProxyModel(QSortFilterProxyModel): @staticmethod @@ -223,10 +236,7 @@ def _convert(text: str) -> Union[str, int]: @staticmethod def _alphanum_key(key: str) -> List[Union[str, int]]: - return [ - ScoreDocumentsProxyModel._convert(c) - for c in re.split("([0-9]+)", key) - ] + return [ScoreDocumentsProxyModel._convert(c) for c in re.split("([0-9]+)", key)] def lessThan(self, left_ind, right_ind): """ @@ -259,20 +269,26 @@ class OWScoreDocuments(OWWidget, ConcurrentWidgetMixin): # default order - table sorted in input order DEFAULT_SORTING = (-1, Qt.AscendingOrder) + settingsHandler = DomainContextHandler() auto_commit: bool = Setting(True) aggregation: int = Setting(0) word_frequency: bool = Setting(True) word_appearance: bool = Setting(False) embedding_similarity: bool = Setting(False) + embedding_language: int = Setting(0) + sort_column_order: Tuple[int, int] = Setting(DEFAULT_SORTING) - embedding_language = Setting(0) + selected_rows: List[int] = ContextSetting([], schema_only=True) + sel_method: int = ContextSetting(SelectionMethods.N_BEST) + n_selected: int = ContextSetting(3) class Inputs: corpus = Input("Corpus", Corpus) words = Input("Words", Table) class Outputs: + selected_documents = Output("Selected documents", Corpus, default=True) corpus = Output("Corpus", Corpus) class Warning(OWWidget.Warning): @@ -327,6 +343,33 @@ def _setup_control_area(self) -> None: gui.rubber(self.controlArea) gui.auto_send(self.buttonsArea, self, "auto_commit") + # select words box + box = gui.vBox(self.controlArea, "Select Words") + grid = QGridLayout() + grid.setContentsMargins(0, 0, 0, 0) + box.layout().addLayout(grid) + + self._sel_method_buttons = QButtonGroup() + for method, label in enumerate(SelectionMethods.ITEMS): + button = QRadioButton(label) + button.setChecked(method == self.sel_method) + grid.addWidget(button, method, 0) + self._sel_method_buttons.addButton(button, method) + self._sel_method_buttons.buttonClicked[int].connect( + self.__set_selection_method + ) + + spin = gui.spin( + box, + self, + "n_selected", + 1, + 999, + addToLayout=False, + callback=lambda: self.__set_selection_method(SelectionMethods.N_BEST), + ) + grid.addWidget(spin, 3, 1) + def _setup_main_area(self) -> None: self._filter_line_edit = QLineEdit( textChanged=self.__on_filter_changed, placeholderText="Filter..." @@ -336,7 +379,11 @@ def _setup_main_area(self) -> None: self.model = model = ScoreDocumentsTableModel(parent=self) model.setHorizontalHeaderLabels(["Document"]) + def select_manual(): + self.__set_selection_method(SelectionMethods.MANUAL) + self.view = view = ScoreDocumentsTableView() + view.pressedAny.connect(select_manual) self.mainArea.layout().addWidget(view) # by default data are sorted in the Table order header = self.view.horizontalHeader() @@ -347,6 +394,7 @@ def _setup_main_area(self) -> None: proxy_model.setFilterCaseSensitivity(False) view.setModel(proxy_model) view.model().setSourceModel(self.model) + self.view.selectionModel().selectionChanged.connect(self.__on_selection_change) def __on_filter_changed(self) -> None: model = self.view.model() @@ -355,15 +403,38 @@ def __on_filter_changed(self) -> None: def __on_horizontal_header_clicked(self, index: int): header = self.view.horizontalHeader() self.sort_column_order = (index, header.sortIndicatorOrder()) + self._select_rows() + # when sorting change output table must consider the new order + # call explicitly since selection in table is not changed + if ( + self.sel_method == SelectionMethods.MANUAL + and self.selected_rows + or self.sel_method == SelectionMethods.ALL + ): + # retrieve selection in new order + self.selected_rows = self.get_selected_indices() + self._send_output() + + def __on_selection_change(self): + self.selected_rows = self.get_selected_indices() + self._send_output() + + def __set_selection_method(self, method: int): + self.sel_method = method + self._sel_method_buttons.button(method).setChecked(True) + self._select_rows() @Inputs.corpus def set_data(self, corpus: Corpus) -> None: + self.closeContext() self.Warning.corpus_not_normalized.clear() if corpus is not None: self.Warning.missing_corpus.clear() if not self._is_corpus_normalized(corpus): self.Warning.corpus_not_normalized() self.corpus = corpus + self.openContext(corpus) + self._sel_method_buttons.button(self.sel_method).setChecked(True) self._clear_and_run() @staticmethod @@ -391,10 +462,7 @@ def avg_len(attr): @Inputs.words def set_words(self, words: Table) -> None: - if ( - words is None - or len(words.domain.variables + words.domain.metas) == 0 - ): + if words is None or len(words.domain.variables + words.domain.metas) == 0: self.words = None else: self.Warning.missing_words.clear() @@ -418,18 +486,20 @@ def _gather_scores(self) -> Tuple[np.ndarray, List[str]]: scorers = self._get_active_scorers() methods = [m for m in scorers if (m, aggregation) in self.scores] scores = [self.scores[(m, aggregation)] for m in methods] - scores = ( - np.column_stack(scores) - if scores - else np.empty((len(self.corpus), 0)) - ) + scores = np.column_stack(scores) if scores else np.empty((len(self.corpus), 0)) labels = [SCORING_METHODS[m][0] for m in methods] return scores, labels - def _send_output(self, scores: np.ndarray, labels: List[str]) -> None: + def _send_output(self) -> None: """ Create corpus with scores and output it """ + if self.corpus is None: + self.Outputs.corpus.send(None) + self.Outputs.selected_documents.send(None) + return + + scores, labels = self._gather_scores() if labels: d = self.corpus.domain domain = Domain( @@ -437,42 +507,45 @@ def _send_output(self, scores: np.ndarray, labels: List[str]) -> None: d.class_var, metas=d.metas + tuple(ContinuousVariable(l) for l in labels), ) - corpus = Corpus( + out_corpus = Corpus( domain, self.corpus.X, self.corpus.Y, np.hstack([self.corpus.metas, scores]), ) - Corpus.retain_preprocessing(self.corpus, corpus) - self.Outputs.corpus.send(corpus) - elif self.corpus is not None: - self.Outputs.corpus.send(self.corpus) + Corpus.retain_preprocessing(self.corpus, out_corpus) else: - self.Outputs.corpus.send(None) + out_corpus = self.corpus + + self.Outputs.corpus.send(create_annotated_table(out_corpus, self.selected_rows)) + self.Outputs.selected_documents.send( + out_corpus[self.selected_rows] if self.selected_rows else None + ) - def _fill_table(self, scores: np.ndarray, labels: List[str]) -> None: + def _fill_table(self) -> None: """ Fill the table in the widget with scores and document names """ if self.corpus is None: self.model.clear() return + scores, labels = self._gather_scores() labels = ["Document"] + labels titles = self.corpus.titles.tolist() + self.model.wrap([[c] + s for c, s in zip(titles, scores.tolist())]) self.model.setHorizontalHeaderLabels(labels) self.view.update_column_widths() - self.view.horizontalHeader().setSortIndicator(*self.sort_column_order) + self._select_rows() def _fill_and_output(self) -> None: - """ Fill the table in the widget and send the output """ - scores, labels = self._gather_scores() - self._fill_table(scores, labels) - self._send_output(scores, labels) + """Fill the table in the widget and send the output""" + self._fill_table() + self._send_output() def _clear_and_run(self) -> None: - """ Clear cached scores and commit """ + """Clear cached scores and commit""" self.scores = {} self.cancel() self._fill_and_output() @@ -493,9 +566,7 @@ def commit(self) -> None: else: scorers = self._get_active_scorers() aggregation = self._get_active_aggregation() - new_scores = [ - s for s in scorers if (s, aggregation) not in self.scores - ] + new_scores = [s for s in scorers if (s, aggregation) not in self.scores] if new_scores: self.start( _run, @@ -512,14 +583,12 @@ def commit(self) -> None: self._fill_and_output() def on_done(self, _: None) -> None: - scores, labels = self._gather_scores() - self._send_output(scores, labels) + self._send_output() def on_partial_result(self, result: Tuple[str, str, np.ndarray]) -> None: sc_method, aggregation, scores = result self.scores[(sc_method, aggregation)] = scores - scores, labels = self._gather_scores() - self._fill_table(scores, labels) + self._fill_table() def on_exception(self, ex: Exception) -> None: self.Error.unknown_err(ex) @@ -554,6 +623,40 @@ def _is_corpus_normalized(corpus: Corpus) -> bool: for pp in corpus.used_preprocessor.preprocessors ) + def get_selected_indices(self) -> List[int]: + # get indices in table's order - that the selected output table have same order + selected_rows = sorted( + self.view.selectionModel().selectedRows(), key=lambda idx: idx.row() + ) + return [self.view.model().mapToSource(r).row() for r in selected_rows] + + def _select_rows(self): + model = self.view.model() + n_rows, n_columns = model.rowCount(), model.columnCount() + if self.sel_method == SelectionMethods.NONE: + selection = QItemSelection() + elif self.sel_method == SelectionMethods.ALL: + selection = QItemSelection( + model.index(0, 0), model.index(n_rows - 1, n_columns - 1) + ) + elif self.sel_method == SelectionMethods.MANUAL: + selection = QItemSelection() + for row in self.selected_rows: + _selection = QItemSelection( + model.mapFromSource(self.model.index(row, 0)), + model.mapFromSource(self.model.index(row, n_columns - 1)), + ) + selection.merge(_selection, QItemSelectionModel.Select) + elif self.sel_method == SelectionMethods.N_BEST: + n_sel = min(self.n_selected, n_rows) + selection = QItemSelection( + model.index(0, 0), model.index(n_sel - 1, n_columns - 1) + ) + else: + raise NotImplementedError + + self.view.selectionModel().select(selection, QItemSelectionModel.ClearAndSelect) + if __name__ == "__main__": from orangewidget.utils.widgetpreview import WidgetPreview diff --git a/orangecontrib/text/widgets/tests/test_owscoredocuments.py b/orangecontrib/text/widgets/tests/test_owscoredocuments.py index 69b31e090..d9ccee032 100644 --- a/orangecontrib/text/widgets/tests/test_owscoredocuments.py +++ b/orangecontrib/text/widgets/tests/test_owscoredocuments.py @@ -1,19 +1,21 @@ import unittest from math import isclose -from typing import List, Union +from typing import List from unittest.mock import patch import numpy as np -from AnyQt.QtCore import Qt +from AnyQt.QtCore import QItemSelectionModel, Qt +from Orange.data import ContinuousVariable, Domain, StringVariable, Table +from Orange.misc.collections import natural_sorted from Orange.widgets.tests.base import WidgetTest -from Orange.data import Table, StringVariable, Domain, ContinuousVariable from Orange.widgets.tests.utils import simulate -from Orange.misc.collections import natural_sorted -from orangecontrib.text import Corpus -from orangecontrib.text import preprocess +from orangecontrib.text import Corpus, preprocess from orangecontrib.text.vectorization.document_embedder import DocumentEmbedder -from orangecontrib.text.widgets.owscoredocuments import OWScoreDocuments +from orangecontrib.text.widgets.owscoredocuments import ( + OWScoreDocuments, + SelectionMethods, +) def embedding_mock(_, corpus, __): @@ -59,9 +61,7 @@ def setUp(self) -> None: def test_set_data(self): self.send_signal(self.widget.Inputs.corpus, self.corpus) - self.assertEqual( - [x[0] for x in self.widget.model], self.corpus.titles.tolist() - ) + self.assertEqual([x[0] for x in self.widget.model], self.corpus.titles.tolist()) self.assertTrue(self.widget.Warning.missing_words.is_shown()) self.send_signal(self.widget.Inputs.words, self.words) @@ -71,13 +71,10 @@ def test_set_data(self): self.assertTrue(all(len(x) == 2 for x in self.widget.model)) output = self.get_output(self.widget.Outputs.corpus) - self.assertTupleEqual( - output.domain.variables, self.corpus.domain.variables - ) - self.assertTupleEqual( - output.domain.metas[:-1], self.corpus.domain.metas - ) - self.assertEqual(str(output.domain.metas[-1]), "Word count") + self.assertTupleEqual(output.domain.variables, self.corpus.domain.variables) + self.assertTupleEqual(output.domain.metas[:1], self.corpus.domain.metas) + self.assertEqual(str(output.domain.metas[1]), "Word count") + self.assertEqual(str(output.domain.metas[2]), "Selected") self.assertEqual(len(output), len(self.corpus)) def test_corpus_not_normalized(self): @@ -101,12 +98,8 @@ def test_guess_word_attribute(self): w = StringVariable("Words") w.attributes["type"] = "words" w1 = StringVariable("Words 1") - words = np.array(["house", "doctor", "boy", "way", "Rum"]).reshape( - (-1, 1) - ) - words1 = np.array(["house", "doctor1", "boy", "way", "Rum"]).reshape( - (-1, 1) - ) + words = np.array(["house", "doctor", "boy", "way", "Rum"]).reshape((-1, 1)) + words1 = np.array(["house", "doctor1", "boy", "way", "Rum"]).reshape((-1, 1)) # guess by attribute type self.words = Table( @@ -128,9 +121,7 @@ def test_guess_word_attribute(self): # guess by length w2 = StringVariable("Words 2") - words2 = np.array(["house 1", "doctor 1", "boy", "way", "Rum"]).reshape( - (-1, 1) - ) + words2 = np.array(["house 1", "doctor 1", "boy", "way", "Rum"]).reshape((-1, 1)) self.words = Table( Domain([], metas=[w2, w1]), np.empty((len(words), 0)), @@ -183,7 +174,7 @@ def test_change_scorer(self): @staticmethod def create_corpus(texts: List[str]) -> Corpus: - """ Create sample corpus with texts passed """ + """Create sample corpus with texts passed""" text_var = StringVariable("Text") domain = Domain([], metas=[text_var]) c = Corpus( @@ -235,9 +226,7 @@ def test_word_appearance(self): self.widget.controls.word_frequency.click() self.widget.controls.word_appearance.click() self.wait_until_finished() - self.assertListEqual( - [x[1] for x in self.widget.model], [2 / 3, 2 / 3, 1] - ) + self.assertListEqual([x[1] for x in self.widget.model], [2 / 3, 2 / 3, 1]) cb_aggregation = self.widget.controls.aggregation simulate.combobox_activate_item(cb_aggregation, "Max") @@ -336,6 +325,49 @@ def test_sort_setting(self): self.assertTupleEqual((1, Qt.DescendingOrder), current_sorting) self.assertListEqual(sorted(data, reverse=True), data) + def test_selection_none(self): + self.send_signal(self.widget.Inputs.corpus, self.corpus) + self.widget._sel_method_buttons.button(SelectionMethods.NONE).click() + + output = self.get_output(self.widget.Outputs.selected_documents) + self.assertIsNone(output) + + def tests_selection_all(self): + self.send_signal(self.widget.Inputs.corpus, self.corpus) + self.widget._sel_method_buttons.button(SelectionMethods.ALL).click() + + output = self.get_output(self.widget.Outputs.selected_documents) + self.assertEqual(len(self.corpus), len(output)) + + def test_selection_manual(self): + self.send_signal(self.widget.Inputs.corpus, self.corpus) + self.widget._sel_method_buttons.button(SelectionMethods.MANUAL).click() + + mode = QItemSelectionModel.Rows | QItemSelectionModel.Select + view = self.widget.view + view.clearSelection() + model = view.model() + view.selectionModel().select(model.index(2, 0), mode) + view.selectionModel().select(model.index(3, 0), mode) + + output = self.get_output(self.widget.Outputs.selected_documents) + self.assertListEqual(["Document 3", "Document 4"], output.titles.tolist()) + + def test_selection_n_best(self): + self.send_signal(self.widget.Inputs.corpus, self.corpus) + self.widget._sel_method_buttons.button(SelectionMethods.N_BEST).click() + + output = self.get_output(self.widget.Outputs.selected_documents) + self.assertListEqual( + [f"Document {i}" for i in range(1, 4)], output.titles.tolist() + ) + + self.widget.controls.n_selected.setValue(5) + output = self.get_output(self.widget.Outputs.selected_documents) + self.assertListEqual( + [f"Document {i}" for i in range(1, 6)], output.titles.tolist() + ) + if __name__ == "__main__": unittest.main()