Skip to content

Commit

Permalink
Score documents: Selection
Browse files Browse the repository at this point in the history
  • Loading branch information
PrimozGodec committed Aug 18, 2021
1 parent 96c295c commit 8186e3a
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 21 deletions.
57 changes: 39 additions & 18 deletions orangecontrib/text/widgets/owscoredocuments.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import List, Callable, Tuple, Union

import numpy as np
from Orange.widgets.utils.annotated_data import create_annotated_table
from pandas import isnull
from Orange.data import (
Table,
Expand All @@ -19,7 +20,7 @@
from orangewidget.settings import Setting
from Orange.widgets.utils.itemmodels import PyTableModel, TableModel
from AnyQt.QtWidgets import QTableView, QLineEdit, QHeaderView
from AnyQt.QtCore import Qt, QSortFilterProxyModel
from AnyQt.QtCore import Qt, QSortFilterProxyModel, QItemSelection
from sklearn.metrics.pairwise import cosine_similarity

from orangecontrib.text import Corpus
Expand Down Expand Up @@ -189,7 +190,8 @@ def __init__(self):
super().__init__(
sortingEnabled=True,
editTriggers=QTableView.NoEditTriggers,
selectionMode=QTableView.NoSelection,
selectionMode=QTableView.ExtendedSelection,
selectionBehavior=QTableView.SelectRows,
cornerButtonEnabled=False,
)
self.setItemDelegate(gui.ColoredBarItemDelegate(self))
Expand Down Expand Up @@ -273,6 +275,7 @@ class Inputs:
words = Input("Words", Table)

class Outputs:
selected_documents = Output("Selected documents", Corpus, default=True)
corpus = Output("Corpus", Corpus)

class Warning(OWWidget.Warning):
Expand Down Expand Up @@ -347,6 +350,7 @@ def _setup_main_area(self) -> None:
proxy_model.setFilterCaseSensitivity(False)
view.setModel(proxy_model)
view.model().setSourceModel(self.model)
self.view.selectionModel().selectionChanged.connect(self._send_output)

def __on_filter_changed(self) -> None:
model = self.view.model()
Expand Down Expand Up @@ -426,50 +430,59 @@ def _gather_scores(self) -> Tuple[np.ndarray, List[str]]:
labels = [SCORING_METHODS[m][0] for m in methods]
return scores, labels

def _send_output(self, scores: np.ndarray, labels: List[str]) -> None:
def _send_output(self) -> None:
"""
Create corpus with scores and output it
"""
if self.corpus is None:
self.Outputs.corpus.send(None)
self.Outputs.selected_documents.send(None)
return

scores, labels = self._gather_scores()
if labels:
d = self.corpus.domain
domain = Domain(
d.attributes,
d.class_var,
metas=d.metas + tuple(ContinuousVariable(l) for l in labels),
)
corpus = Corpus(
out_corpus = Corpus(
domain,
self.corpus.X,
self.corpus.Y,
np.hstack([self.corpus.metas, scores]),
)
Corpus.retain_preprocessing(self.corpus, corpus)
self.Outputs.corpus.send(corpus)
elif self.corpus is not None:
self.Outputs.corpus.send(self.corpus)
Corpus.retain_preprocessing(self.corpus, out_corpus)
else:
self.Outputs.corpus.send(None)
out_corpus = self.corpus

def _fill_table(self, scores: np.ndarray, labels: List[str]) -> None:
selected_indices = self.get_selected_indices()
self.Outputs.corpus.send(create_annotated_table(out_corpus, selected_indices))
self.Outputs.selected_documents.send(out_corpus[selected_indices] if selected_indices else None)

def _fill_table(self) -> None:
"""
Fill the table in the widget with scores and document names
"""
if self.corpus is None:
self.model.clear()
return
scores, labels = self._gather_scores()
labels = ["Document"] + labels
titles = self.corpus.titles.tolist()

selected_indices = [i.row() for i in self.view.selectionModel().selectedRows()]
self.model.wrap([[c] + s for c, s in zip(titles, scores.tolist())])
self.model.setHorizontalHeaderLabels(labels)
self.view.update_column_widths()

self.view.horizontalHeader().setSortIndicator(*self.sort_column_order)
self._select_rows(selected_indices)

def _fill_and_output(self) -> None:
""" Fill the table in the widget and send the output """
scores, labels = self._gather_scores()
self._fill_table(scores, labels)
self._send_output(scores, labels)
self._fill_table()
self._send_output()

def _clear_and_run(self) -> None:
""" Clear cached scores and commit """
Expand Down Expand Up @@ -512,14 +525,12 @@ def commit(self) -> None:
self._fill_and_output()

def on_done(self, _: None) -> None:
scores, labels = self._gather_scores()
self._send_output(scores, labels)
self._send_output()

def on_partial_result(self, result: Tuple[str, str, np.ndarray]) -> None:
sc_method, aggregation, scores = result
self.scores[(sc_method, aggregation)] = scores
scores, labels = self._gather_scores()
self._fill_table(scores, labels)
self._fill_table()

def on_exception(self, ex: Exception) -> None:
self.Error.unknown_err(ex)
Expand Down Expand Up @@ -554,6 +565,16 @@ def _is_corpus_normalized(corpus: Corpus) -> bool:
for pp in corpus.used_preprocessor.preprocessors
)

def get_selected_indices(self):
    """
    Return the source-model row numbers of the rows selected in the view.

    The view's selection model reports indexes of the model set on the
    view; each one is translated with ``mapToSource`` before its row number
    is taken. The rows are returned as a list, in the order the selection
    model reports them.
    """
    chosen = self.view.selectionModel().selectedRows()
    source_rows = []
    for view_index in chosen:
        source_rows.append(self.view.model().mapToSource(view_index).row())
    return source_rows

def _select_rows(self, selected_indices):
    """
    Select the given rows in the view.

    ``selected_indices`` is an iterable of row numbers; each is passed to
    ``self.view.selectRow``, so the numbers address rows of the view.

    The view is put into ``MultiSelection`` mode for the duration of the
    loop and set back to ``ExtendedSelection`` afterwards.
    """
    # Presumably MultiSelection is needed so that successive selectRow
    # calls accumulate instead of replacing one another - confirm against
    # Qt's selection-mode semantics.
    self.view.setSelectionMode(QTableView.MultiSelection)
    try:
        for row in selected_indices:
            self.view.selectRow(row)
    finally:
        # Restore the interactive mode even if selecting a row fails, so
        # the view is never left stuck in MultiSelection.
        self.view.setSelectionMode(QTableView.ExtendedSelection)


if __name__ == "__main__":
from orangewidget.utils.widgetpreview import WidgetPreview
Expand Down
57 changes: 54 additions & 3 deletions orangecontrib/text/widgets/tests/test_owscoredocuments.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from unittest.mock import patch

import numpy as np
from AnyQt.QtCore import Qt
from AnyQt.QtCore import Qt, QItemSelectionModel
from Orange.widgets.tests.base import WidgetTest
from Orange.data import Table, StringVariable, Domain, ContinuousVariable
from Orange.widgets.tests.utils import simulate
Expand Down Expand Up @@ -75,9 +75,10 @@ def test_set_data(self):
output.domain.variables, self.corpus.domain.variables
)
self.assertTupleEqual(
output.domain.metas[:-1], self.corpus.domain.metas
output.domain.metas[:1], self.corpus.domain.metas
)
self.assertEqual(str(output.domain.metas[-1]), "Word count")
self.assertEqual(str(output.domain.metas[1]), "Word count")
self.assertEqual(str(output.domain.metas[2]), "Selected")
self.assertEqual(len(output), len(self.corpus))

def test_corpus_not_normalized(self):
Expand Down Expand Up @@ -336,6 +337,56 @@ def test_sort_setting(self):
self.assertTupleEqual((1, Qt.DescendingOrder), current_sorting)
self.assertListEqual(sorted(data, reverse=True), data)

def test_selection(self):
    """
    Rows selected in the view appear on the "Selected documents" output,
    both in the initial row order and after the table has been re-sorted.
    """
    self.send_signal(self.widget.Inputs.corpus, self.corpus)

    table_view = self.widget.view
    view_model = table_view.model()
    flags = QItemSelectionModel.Rows | QItemSelectionModel.Select

    def select_rows(*rows):
        # select whole rows one at a time through the view's selection model
        for row in rows:
            table_view.selectionModel().select(view_model.index(row, 0), flags)

    def selected_titles():
        selected = self.get_output(self.widget.Outputs.selected_documents)
        return selected.titles.tolist()

    select_rows(2, 3)
    self.assertListEqual(["Document 3", "Document 4"], selected_titles())

    # change table order and check if correct selection at the output
    table_view.sortByColumn(0, Qt.DescendingOrder)
    table_view.clearSelection()
    select_rows(0, 1)
    self.assertListEqual(["Document 140", "Document 139"], selected_titles())
    table_view.clearSelection()

def test_selection_socres(self):
    """
    Repeat the selection test with a words table on the input, so that
    score columns are present in the table.
    """
    self.send_signal(self.widget.Inputs.corpus, self.corpus)
    words = self.create_words_table(["house", "doctor", "boy", "way", "Rum"])
    self.send_signal(self.widget.Inputs.words, words)
    self.wait_until_finished()

    mode = QItemSelectionModel.Rows | QItemSelectionModel.Select

    # NOTE(review): these indexes are created by the widget's own model
    # (self.widget.model), while test_selection uses view.model(); the
    # selection model they are passed to belongs to the view. Confirm that
    # mixing the two is intended.
    model = self.widget.model
    view = self.widget.view
    view.selectionModel().select(model.index(2, 0), mode)
    view.selectionModel().select(model.index(3, 0), mode)

    output = self.get_output(self.widget.Outputs.selected_documents)
    self.assertListEqual(["Document 3", "Document 4"], output.titles.tolist())

    # change table order and check if correct selection at the output
    view.sortByColumn(0, Qt.DescendingOrder)
    view.clearSelection()
    view.selectionModel().select(model.index(0, 0), mode)
    view.selectionModel().select(model.index(1, 0), mode)

    output = self.get_output(self.widget.Outputs.selected_documents)
    self.assertListEqual(["Document 140", "Document 139"], output.titles.tolist())


if __name__ == "__main__":
unittest.main()

0 comments on commit 8186e3a

Please sign in to comment.