OWPreprocess Text: add option to filter on POS tags #679

Merged · 3 commits · Jul 20, 2021
16 changes: 15 additions & 1 deletion orangecontrib/text/corpus.py
@@ -85,7 +85,7 @@ def __init__(self, domain=None, X=None, Y=None, metas=None, W=None,
self._ngrams_corpus = None
self.ngram_range = (1, 1)
self.attributes = {}
self.pos_tags = None
self._pos_tags = None
from orangecontrib.text.preprocess import PreprocessorList
self.__used_preprocessor = PreprocessorList([]) # required for compute values
self._titles: Optional[np.ndarray] = None
@@ -448,6 +448,20 @@ def dictionary(self):
return self._base_tokens()[1]
return self._dictionary

@property
def pos_tags(self):
"""
np.ndarray: A list of lists containing POS tags. If there are no
POS tags available, return None.
"""
if self._pos_tags is None:
return None
return np.array(self._pos_tags, dtype=object)

@pos_tags.setter
def pos_tags(self, pos_tags):
self._pos_tags = pos_tags

def ngrams_iterator(self, join_with=' ', include_postags=False):
if self.pos_tags is None:
include_postags = False
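Note: the new pos_tags property wraps the stored per-document tag lists in an object-dtype array. A minimal sketch of the intended behavior, using hypothetical tag data that is not part of this diff:

import numpy as np

# Hypothetical stand-in for Corpus._pos_tags after tagging two documents.
stored_tags = [["NN", "VBZ", "JJ"], ["DT", "NN"]]

# Mirrors the property: ragged lists become a 1-D object-dtype array,
# indexable per document like corpus.tokens.
pos_tags = np.array(stored_tags, dtype=object)
print(pos_tags.shape)   # (2,)
print(pos_tags[0])      # ['NN', 'VBZ', 'JJ']
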
47 changes: 46 additions & 1 deletion orangecontrib/text/preprocess/filter.py
@@ -2,6 +2,7 @@
import os
import re

import numpy as np
from gensim import corpora
from nltk.corpus import stopwords

@@ -13,7 +14,8 @@
from orangecontrib.text.preprocess import TokenizedPreprocessor

__all__ = ['BaseTokenFilter', 'StopwordsFilter', 'LexiconFilter',
'RegexpFilter', 'FrequencyFilter', 'MostFrequentTokensFilter']
'RegexpFilter', 'FrequencyFilter', 'MostFrequentTokensFilter',
'PosTagFilter']


class BaseTokenFilter(TokenizedPreprocessor):
@@ -195,3 +197,46 @@ def _fit(self, corpus: Corpus):
self._dictionary = corpora.Dictionary(corpus.tokens)
self._dictionary.filter_extremes(0, 1, self._keep_n)
self._lexicon = set(self._dictionary.token2id.keys())


class PosTagFilter(BaseTokenFilter):
"""Keep selected POS tags."""
name = 'POS tags'

def __init__(self, tags=None):
self._tags = set(i.strip().upper() for i in tags.split(","))

def __call__(self, corpus: Corpus, callback: Callable = None) -> Corpus:
if callback is None:
callback = dummy_callback
corpus = super().__call__(corpus, wrap_callback(callback, end=0.2))
return self._filter_tokens(corpus, wrap_callback(callback, start=0.2))

@staticmethod
def validate_tags(tags):
# should we keep a dict of existing POS tags and compare them with
# input?
return len(tags.split(",")) > 0

def _filter_tokens(self, corpus: Corpus, callback: Callable) -> Corpus:
if corpus.pos_tags is None:
return corpus
callback(0, "Filtering...")
filtered_tags = []
filtered_tokens = []
for tags, tokens in zip(corpus.pos_tags, corpus.tokens):
tmp_tags = []
tmp_tokens = []
for tag, token in zip(tags, tokens):
# should we consider partial matches, i.e. "NN" for "NNS"?
if tag in self._tags:
tmp_tags.append(tag)
tmp_tokens.append(token)
filtered_tags.append(tmp_tags)
filtered_tokens.append(tmp_tokens)
corpus.store_tokens(filtered_tokens)
corpus.pos_tags = filtered_tags
return corpus

def _check(self, token: str) -> bool:
pass
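
For context, a rough usage sketch of the new filter chained after tokenization and POS tagging, mirroring the test below. The tokenizer, tagger, and the book-excerpts dataset are assumed from the existing package and are not part of this diff:

from orangecontrib.text import Corpus
from orangecontrib.text import preprocess, tag

corpus = Corpus.from_file("book-excerpts")
# Tokenize and tag first; PosTagFilter is a no-op when corpus.pos_tags is None.
for pp in (preprocess.WordPunctTokenizer(), tag.AveragedPerceptronTagger()):
    corpus = pp(corpus)
# Keep only tokens tagged exactly NN or VB (no partial matches, e.g. NNS is dropped).
filtered = preprocess.PosTagFilter("NN,VB")(corpus)
print(filtered.tokens[0][:5])
print(filtered.pos_tags[0][:5])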
12 changes: 12 additions & 0 deletions orangecontrib/text/tests/test_preprocess.py
@@ -435,6 +435,18 @@ def test_regex_filter(self):
self.assertEqual(filtered.tokens[0], [' http'])
self.assertEqual(len(filtered.used_preprocessor.preprocessors), 2)

def test_pos_filter(self):
pos_filter = preprocess.PosTagFilter("NN")
pp_list = [preprocess.WordPunctTokenizer(),
tag.AveragedPerceptronTagger()]
corpus = self.corpus
for pp in pp_list:
corpus = pp(corpus)
filtered = pos_filter(corpus)
self.assertTrue(len(filtered.pos_tags))
self.assertEqual(len(filtered.pos_tags[0]), 5)
self.assertEqual(len(filtered.tokens[0]), 5)

def test_can_deepcopy(self):
copied = copy.deepcopy(self.regexp)
self.corpus.metas[0, 0] = 'foo bar'
33 changes: 31 additions & 2 deletions orangecontrib/text/widgets/owpreprocess.py
@@ -501,13 +501,15 @@ def __repr__(self):


class FilteringModule(MultipleMethodModule):
Stopwords, Lexicon, Regexp, DocFreq, DummyDocFreq, MostFreq = range(6)
Stopwords, Lexicon, Regexp, DocFreq, DummyDocFreq, MostFreq, PosTag = \
range(7)
Methods = {Stopwords: StopwordsFilter,
Lexicon: LexiconFilter,
Regexp: RegexpFilter,
DocFreq: FrequencyFilter,
DummyDocFreq: FrequencyFilter,
MostFreq: MostFrequentTokensFilter}
MostFreq: MostFrequentTokensFilter,
PosTag: PosTagFilter}
DEFAULT_METHODS = [Stopwords]
DEFAULT_LANG = "English"
DEFAULT_NONE = None
@@ -517,6 +519,7 @@ class FilteringModule(MultipleMethodModule):
DEFAULT_REL_START, DEFAULT_REL_END, REL_MIN, REL_MAX = 0.1, 0.9, 0, 1
DEFAULT_ABS_START, DEFAULT_ABS_END, ABS_MIN, ABS_MAX = 1, 10, 0, 10000
DEFAULT_N_TOKEN = 100
DEFAULT_POS_TAGS = "NOUN,VERB"

def __init__(self, parent=None, **kwargs):
super().__init__(parent, **kwargs)
@@ -530,6 +533,7 @@ def __init__(self, parent=None, **kwargs):
self.__abs_freq_st = self.DEFAULT_ABS_START
self.__abs_freq_en = self.DEFAULT_ABS_END
self.__n_token = self.DEFAULT_N_TOKEN
self.__pos_tag = self.DEFAULT_POS_TAGS
self.__invalidated = False

self.__combo = ComboBox(
@@ -574,6 +578,10 @@ def __init__(self, parent=None, **kwargs):
self.__spin_n.editingFinished.connect(self.__spin_n_edited)
self.__spin_n.valueChanged.connect(self.changed)

validator = PosTagFilter.validate_tags
self.__pos_edit = ValidatedLineEdit(self.__pos_tag, validator)
self.__pos_edit.editingFinished.connect(self.__pos_edit_finished)

self.layout().addWidget(self.__combo, self.Stopwords, 1)
self.layout().addWidget(self.__sw_loader.file_combo,
self.Stopwords, 2, 1, 2)
@@ -595,6 +603,7 @@ def __init__(self, parent=None, **kwargs):
title = self.layout().itemAtPosition(self.DummyDocFreq, 0).widget()
title.hide()
self.layout().addWidget(self.__spin_n, self.MostFreq, 1)
self.layout().addWidget(self.__pos_edit, self.PosTag, 1, 1, 5)
self.layout().setColumnStretch(3, 1)

def __sw_loader_activated(self):
@@ -626,6 +635,13 @@ def __edit_finished(self):
if self.Regexp in self.methods:
self.edited.emit()

def __pos_edit_finished(self):
tags = self.__pos_edit.text()
if self.__pos_tag != tags:
self.__set_tags(tags)
if self.PosTag in self.methods:
self.edited.emit()

def __freq_group_clicked(self):
i = self.__freq_group.checkedId()
if self.__freq_type != i:
@@ -666,6 +682,7 @@ def setParameters(self, params: Dict):
params.get("abs_end", self.DEFAULT_ABS_END)
)
self.__set_n_tokens(params.get("n_tokens", self.DEFAULT_N_TOKEN))
self.__set_tags(params.get("pos_tags", self.DEFAULT_POS_TAGS))
self.__invalidated = False

def __set_language(self, language: str):
@@ -694,6 +711,12 @@ def __set_pattern(self, pattern: str):
self.__edit.setText(pattern)
self.changed.emit()

def __set_tags(self, tags: str):
if self.__pos_tag != tags:
self.__pos_tag = tags
self.__pos_edit.setText(tags)
self.changed.emit()

def __set_freq_type(self, freq_type: int):
if self.__freq_type != freq_type:
self.__freq_type = freq_type
Expand Down Expand Up @@ -750,6 +773,7 @@ def parameters(self) -> Dict:
"abs_start": self.__abs_freq_st,
"abs_end": self.__abs_freq_en,
"n_tokens": self.__n_token,
"pos_tags": self.__pos_tag,
"invalidated": self.__invalidated})
return params

@@ -782,6 +806,9 @@ def map_none(s):
if FilteringModule.MostFreq in methods:
n = params.get("n_tokens", FilteringModule.DEFAULT_N_TOKEN)
filters.append(MostFrequentTokensFilter(keep_n=n))
if FilteringModule.PosTag in methods:
tags = params.get("pos_tags", FilteringModule.DEFAULT_POS_TAGS)
filters.append(PosTagFilter(tags=tags))
return filters

def __repr__(self):
@@ -801,6 +828,8 @@ def __repr__(self):
append = f"[{self.__abs_freq_st}, {self.__abs_freq_en}]"
elif method == self.MostFreq:
append = f"{self.__n_token}"
elif method == self.PosTag:
append = f"{self.__pos_tag}"
texts.append(f"{self.Methods[method].name} ({append})")
return ", ".join(texts)

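The stored pos_tags string round-trips through parameters() and the filter construction in the hunk above, so reopening a saved workflow rebuilds the same filter. A sketch only, assuming the module exposes the createinstance staticmethod and the "methods" parameter key used elsewhere in this widget:

from orangecontrib.text.widgets.owpreprocess import FilteringModule

params = {"methods": [FilteringModule.PosTag], "pos_tags": "NN,VBZ"}
filters = FilteringModule.createinstance(params)
# filters -> [PosTagFilter(tags="NN,VBZ")]
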
6 changes: 4 additions & 2 deletions orangecontrib/text/widgets/tests/test_owpreprocess.py
@@ -485,7 +485,8 @@ def test_parameters(self):
"freq_type": 0,
"rel_start": 0.1, "rel_end": 0.9,
"abs_start": 1, "abs_end": 10,
"n_tokens": 100, "invalidated": False}
"n_tokens": 100, "pos_tags": "NOUN,VERB",
"invalidated": False}
self.assertDictEqual(self.editor.parameters(), params)

def test_set_parameters(self):
@@ -499,7 +500,8 @@ def test_set_parameters(self):
"freq_type": 1,
"rel_start": 0.2, "rel_end": 0.7,
"abs_start": 2, "abs_end": 15,
"n_tokens": 10, "invalidated": False}
"n_tokens": 10, "pos_tags": "JJ",
"invalidated": False}
self.editor.setParameters(params)
self.assertDictEqual(self.editor.parameters(), params)
