Skip to content

Commit

Permalink
Score documents: Selection
Browse files Browse the repository at this point in the history
  • Loading branch information
PrimozGodec committed Sep 6, 2021
1 parent 7a5f5fe commit b521d15
Show file tree
Hide file tree
Showing 2 changed files with 243 additions and 28 deletions.
214 changes: 190 additions & 24 deletions orangecontrib/text/widgets/owscoredocuments.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,38 @@
import re
from collections import Counter
from contextlib import contextmanager
from inspect import signature
from typing import Callable, List, Tuple, Union

import numpy as np
from AnyQt.QtCore import QSortFilterProxyModel, Qt
from AnyQt.QtWidgets import QHeaderView, QLineEdit, QTableView
from AnyQt.QtCore import (
QItemSelection,
QItemSelectionModel,
QSortFilterProxyModel,
Qt,
Signal,
)
from AnyQt.QtWidgets import (
QButtonGroup,
QGridLayout,
QHeaderView,
QLineEdit,
QRadioButton,
QTableView,
)
from pandas import isnull
from sklearn.metrics.pairwise import cosine_similarity

# todo: uncomment when minimum version of Orange is 3.29.2
# from orangecanvas.gui.utils import disconnected
from orangewidget import gui
from Orange.data import ContinuousVariable, Domain, StringVariable, Table
from Orange.util import wrap_callback
from Orange.widgets.settings import ContextSetting, PerfectDomainContextHandler, Setting
from Orange.widgets.utils.annotated_data import create_annotated_table
from Orange.widgets.utils.concurrent import ConcurrentWidgetMixin, TaskState
from Orange.widgets.utils.itemmodels import PyTableModel, TableModel
from Orange.widgets.widget import Input, Msg, Output, OWWidget
from orangewidget import gui
from orangewidget.settings import Setting
from pandas import isnull
from sklearn.metrics.pairwise import cosine_similarity

from orangecontrib.text import Corpus
from orangecontrib.text.preprocess import BaseNormalizer, BaseTransformer
Expand All @@ -24,6 +42,16 @@
)


# todo: remove when minimum version of Orange is 3.29.2
@contextmanager
def disconnected(signal, slot, type=Qt.UniqueConnection):
signal.disconnect(slot)
try:
yield
finally:
signal.connect(slot, type)


def _word_frequency(corpus: Corpus, words: List[str], callback: Callable) -> np.ndarray:
res = []
tokens = corpus.tokens
Expand Down Expand Up @@ -183,12 +211,20 @@ def callback(i: float) -> None:
state.set_partial_result((sm, aggregation, scs))


class SelectionMethods:
NONE, ALL, MANUAL, N_BEST = range(4)
ITEMS = "None", "All", "Manual", "Top documents"


class ScoreDocumentsTableView(QTableView):
pressedAny = Signal()

def __init__(self):
super().__init__(
sortingEnabled=True,
editTriggers=QTableView.NoEditTriggers,
selectionMode=QTableView.NoSelection,
selectionMode=QTableView.ExtendedSelection,
selectionBehavior=QTableView.SelectRows,
cornerButtonEnabled=False,
)
self.setItemDelegate(gui.ColoredBarItemDelegate(self))
Expand All @@ -215,6 +251,10 @@ def update_column_widths(self) -> None:
# document title column is one that stretch
header.setSectionResizeMode(0, QHeaderView.Stretch)

def mousePressEvent(self, event):
super().mousePressEvent(event)
self.pressedAny.emit()


class ScoreDocumentsProxyModel(QSortFilterProxyModel):
@staticmethod
Expand Down Expand Up @@ -253,23 +293,31 @@ class OWScoreDocuments(OWWidget, ConcurrentWidgetMixin):
icon = "icons/ScoreDocuments.svg"
priority = 500

buttons_area_orientation = Qt.Vertical

# default order - table sorted in input order
DEFAULT_SORTING = (-1, Qt.AscendingOrder)

settingsHandler = PerfectDomainContextHandler()
auto_commit: bool = Setting(True)
aggregation: int = Setting(0)

word_frequency: bool = Setting(True)
word_appearance: bool = Setting(False)
embedding_similarity: bool = Setting(False)
embedding_language: int = Setting(0)

sort_column_order: Tuple[int, int] = Setting(DEFAULT_SORTING)
embedding_language = Setting(0)
selected_rows: List[int] = ContextSetting([], schema_only=True)
sel_method: int = ContextSetting(SelectionMethods.N_BEST)
n_selected: int = ContextSetting(3)

class Inputs:
corpus = Input("Corpus", Corpus)
words = Input("Words", Table)

class Outputs:
selected_documents = Output("Selected documents", Corpus, default=True)
corpus = Output("Corpus", Corpus)

class Warning(OWWidget.Warning):
Expand Down Expand Up @@ -322,6 +370,33 @@ def _setup_control_area(self) -> None:
)

gui.rubber(self.controlArea)

# select words box
box = gui.vBox(self.buttonsArea, "Select Documents")
grid = QGridLayout()
grid.setContentsMargins(0, 0, 0, 0)

self._sel_method_buttons = QButtonGroup()
for method, label in enumerate(SelectionMethods.ITEMS):
button = QRadioButton(label)
button.setChecked(method == self.sel_method)
grid.addWidget(button, method, 0)
self._sel_method_buttons.addButton(button, method)
self._sel_method_buttons.buttonClicked[int].connect(self.__set_selection_method)

spin = gui.spin(
box,
self,
"n_selected",
1,
999,
addToLayout=False,
callback=lambda: self.__set_selection_method(SelectionMethods.N_BEST),
)
grid.addWidget(spin, 3, 1)
box.layout().addLayout(grid)

# autocommit
gui.auto_send(self.buttonsArea, self, "auto_commit")

def _setup_main_area(self) -> None:
Expand All @@ -333,7 +408,11 @@ def _setup_main_area(self) -> None:
self.model = model = ScoreDocumentsTableModel(parent=self)
model.setHorizontalHeaderLabels(["Document"])

def select_manual():
self.__set_selection_method(SelectionMethods.MANUAL)

self.view = view = ScoreDocumentsTableView()
view.pressedAny.connect(select_manual)
self.mainArea.layout().addWidget(view)
# by default data are sorted in the Table order
header = self.view.horizontalHeader()
Expand All @@ -344,6 +423,7 @@ def _setup_main_area(self) -> None:
proxy_model.setFilterCaseSensitivity(False)
view.setModel(proxy_model)
view.model().setSourceModel(self.model)
self.view.selectionModel().selectionChanged.connect(self.__on_selection_change)

def __on_filter_changed(self) -> None:
model = self.view.model()
Expand All @@ -352,15 +432,39 @@ def __on_filter_changed(self) -> None:
def __on_horizontal_header_clicked(self, index: int):
header = self.view.horizontalHeader()
self.sort_column_order = (index, header.sortIndicatorOrder())
self._select_rows()
# when sorting change output table must consider the new order
# call explicitly since selection in table is not changed
if (
self.sel_method == SelectionMethods.MANUAL
and self.selected_rows
or self.sel_method == SelectionMethods.ALL
):
# retrieve selection in new order
self.selected_rows = self.get_selected_indices()
self._send_output()

def __on_selection_change(self):
self.selected_rows = self.get_selected_indices()
self._send_output()

def __set_selection_method(self, method: int):
self.sel_method = method
self._sel_method_buttons.button(method).setChecked(True)
self._select_rows()

@Inputs.corpus
def set_data(self, corpus: Corpus) -> None:
self.closeContext()
self.Warning.corpus_not_normalized.clear()
if corpus is not None:
self.Warning.missing_corpus.clear()
if not self._is_corpus_normalized(corpus):
self.Warning.corpus_not_normalized()
self.corpus = corpus
self.selected_rows = []
self.openContext(corpus)
self._sel_method_buttons.button(self.sel_method).setChecked(True)
self._clear_and_run()

@staticmethod
Expand Down Expand Up @@ -416,50 +520,71 @@ def _gather_scores(self) -> Tuple[np.ndarray, List[str]]:
labels = [SCORING_METHODS[m][0] for m in methods]
return scores, labels

def _send_output(self, scores: np.ndarray, labels: List[str]) -> None:
def _send_output(self) -> None:
"""
Create corpus with scores and output it
"""
if self.corpus is None:
self.Outputs.corpus.send(None)
self.Outputs.selected_documents.send(None)
return

scores, labels = self._gather_scores()
if labels:
d = self.corpus.domain
domain = Domain(
d.attributes,
d.class_var,
metas=d.metas + tuple(ContinuousVariable(l) for l in labels),
)
corpus = Corpus(
out_corpus = Corpus(
domain,
self.corpus.X,
self.corpus.Y,
np.hstack([self.corpus.metas, scores]),
)
Corpus.retain_preprocessing(self.corpus, corpus)
self.Outputs.corpus.send(corpus)
elif self.corpus is not None:
self.Outputs.corpus.send(self.corpus)
Corpus.retain_preprocessing(self.corpus, out_corpus)
else:
self.Outputs.corpus.send(None)
out_corpus = self.corpus

def _fill_table(self, scores: np.ndarray, labels: List[str]) -> None:
self.Outputs.corpus.send(create_annotated_table(out_corpus, self.selected_rows))
self.Outputs.selected_documents.send(
out_corpus[self.selected_rows] if self.selected_rows else None
)

def _fill_table(self) -> None:
"""
Fill the table in the widget with scores and document names
"""
if self.corpus is None:
self.model.clear()
return
scores, labels = self._gather_scores()
labels = ["Document"] + labels
titles = self.corpus.titles.tolist()

# clearing selection and sorting to prevent SEGFAULT on model.wrap
self.view.horizontalHeader().setSortIndicator(-1, Qt.AscendingOrder)
with disconnected(
self.view.selectionModel().selectionChanged, self.__on_selection_change
):
self.view.clearSelection()

self.model.wrap([[c] + s for c, s in zip(titles, scores.tolist())])
self.model.setHorizontalHeaderLabels(labels)
self.view.update_column_widths()
if self.model.columnCount() > self.sort_column_order[0]:
# if not enough columns do not apply sorting from settings since
# sorting can besaved for score column while scores are still computing
# tables is filled before scores are computed with document names
self.view.horizontalHeader().setSortIndicator(*self.sort_column_order)

self.view.horizontalHeader().setSortIndicator(*self.sort_column_order)
self._select_rows()

def _fill_and_output(self) -> None:
"""Fill the table in the widget and send the output"""
scores, labels = self._gather_scores()
self._fill_table(scores, labels)
self._send_output(scores, labels)
self._fill_table()
self._send_output()

def _clear_and_run(self) -> None:
"""Clear cached scores and commit"""
Expand Down Expand Up @@ -500,14 +625,12 @@ def commit(self) -> None:
self._fill_and_output()

def on_done(self, _: None) -> None:
scores, labels = self._gather_scores()
self._send_output(scores, labels)
self._send_output()

def on_partial_result(self, result: Tuple[str, str, np.ndarray]) -> None:
sc_method, aggregation, scores = result
self.scores[(sc_method, aggregation)] = scores
scores, labels = self._gather_scores()
self._fill_table(scores, labels)
self._fill_table()

def on_exception(self, ex: Exception) -> None:
self.Error.custom_err(ex)
Expand Down Expand Up @@ -543,6 +666,49 @@ def _is_corpus_normalized(corpus: Corpus) -> bool:
for pp in corpus.used_preprocessor.preprocessors
)

def get_selected_indices(self) -> List[int]:
# get indices in table's order - that the selected output table have same order
selected_rows = sorted(
self.view.selectionModel().selectedRows(), key=lambda idx: idx.row()
)
return [self.view.model().mapToSource(r).row() for r in selected_rows]

def _select_rows(self):
proxy_model = self.view.model()
n_rows, n_columns = proxy_model.rowCount(), proxy_model.columnCount()
if self.sel_method == SelectionMethods.NONE:
selection = QItemSelection()
elif self.sel_method == SelectionMethods.ALL:
selection = QItemSelection(
proxy_model.index(0, 0), proxy_model.index(n_rows - 1, n_columns - 1)
)
elif self.sel_method == SelectionMethods.MANUAL:
selection = QItemSelection()
new_sel = []
for row in self.selected_rows:
if row < n_rows:
new_sel.append(row)
_selection = QItemSelection(
self.model.index(row, 0), self.model.index(row, n_columns - 1)
)
selection.merge(
proxy_model.mapSelectionFromSource(_selection),
QItemSelectionModel.Select,
)
# selected rows must be updated when the same dataset with less rows
# appear at the input - it is not handled by selectionChanged
# in cases when all selected rows missing in new table
self.selected_rows = new_sel
elif self.sel_method == SelectionMethods.N_BEST:
n_sel = min(self.n_selected, n_rows)
selection = QItemSelection(
proxy_model.index(0, 0), proxy_model.index(n_sel - 1, n_columns - 1)
)
else:
raise NotImplementedError

self.view.selectionModel().select(selection, QItemSelectionModel.ClearAndSelect)


if __name__ == "__main__":
from orangewidget.utils.widgetpreview import WidgetPreview
Expand Down
Loading

0 comments on commit b521d15

Please sign in to comment.