@property
def pos_tags(self):
    """
    POS tags stored on this object, wrapped in an object-dtype NumPy array.

    Returns:
        ``None`` while no tags have been stored; otherwise
        ``np.array(stored_tags, dtype=object)``.
    """
    stored = self._pos_tags
    # NOTE(review): if every inner sequence has the same length, NumPy
    # builds a 2-D object array rather than a 1-D array of lists.
    return None if stored is None else np.array(stored, dtype=object)


@pos_tags.setter
def pos_tags(self, pos_tags):
    """Remember ``pos_tags`` as given; no copy or conversion is made."""
    self._pos_tags = pos_tags
class PosTagFilter(BaseTokenFilter):
    """Keep only the tokens whose POS tag is one of the selected tags.

    The selection is given as a comma-separated string such as ``"NN,VB"``.
    Each entry is stripped of surrounding whitespace and upper-cased, and
    a token is kept only when its tag equals one of the entries exactly.
    """
    name = 'POS tags'

    def __init__(self, tags=None):
        """
        Args:
            tags: comma-separated POS tags to keep. ``None`` (the default)
                and blank strings select no tag at all; previously ``None``
                raised ``AttributeError`` on ``tags.split``.
        """
        self._tags = self._parse_tags(tags)

    @staticmethod
    def _parse_tags(tags):
        """Turn a comma-separated string into a set of upper-cased tags.

        Blank pieces (e.g. from a trailing comma) are dropped; ``None`` or
        an empty string gives an empty set.
        """
        if not tags:
            return set()
        pieces = (piece.strip().upper() for piece in tags.split(","))
        return {piece for piece in pieces if piece}

    def __call__(self, corpus: Corpus, callback: Callable = None) -> Corpus:
        """Run the parent class's call on ``corpus``, then filter by tag.

        The parent call is given the first 20% of the progress range and
        the tag filtering the remaining 80%.
        """
        if callback is None:
            callback = dummy_callback
        corpus = super().__call__(corpus, wrap_callback(callback, end=0.2))
        return self._filter_tokens(corpus, wrap_callback(callback, start=0.2))

    @staticmethod
    def validate_tags(tags):
        """Return True when ``tags`` names at least one non-blank tag.

        ``str.split(",")`` never returns an empty list, so counting the
        pieces would accept every input, including an empty string.
        """
        # should we keep a dict of existing POS tags and compare them with
        # input?
        return bool(PosTagFilter._parse_tags(tags))

    def _filter_tokens(self, corpus: Corpus, callback: Callable) -> Corpus:
        """Drop every token (and its tag) whose tag is not selected.

        A corpus without POS tags is returned untouched. Otherwise the
        filtered tokens are written back with ``corpus.store_tokens`` and
        the filtered tags assigned to ``corpus.pos_tags``; the same corpus
        object is returned.
        """
        if corpus.pos_tags is None:
            return corpus
        callback(0, "Filtering...")
        selected = self._tags
        filtered_tags = []
        filtered_tokens = []
        for doc_tags, doc_tokens in zip(corpus.pos_tags, corpus.tokens):
            # should we consider partial matches, i.e. "NN" for "NNS"?
            kept = [(tag, token)
                    for tag, token in zip(doc_tags, doc_tokens)
                    if tag in selected]
            # Tags and tokens are filtered together so they stay aligned.
            filtered_tags.append([tag for tag, _ in kept])
            filtered_tokens.append([token for _, token in kept])
        corpus.store_tokens(filtered_tokens)
        corpus.pos_tags = filtered_tags
        return corpus

    def _check(self, token: str) -> bool:
        # Filtering in this class depends on the tag, not on the token
        # text, and ``_filter_tokens`` above never calls this method, so it
        # is left as a no-op that returns None.
        # NOTE(review): presumably kept only to satisfy the base-class
        # interface -- confirm against BaseTokenFilter.
        pass
FilteringModule(MultipleMethodModule): DEFAULT_REL_START, DEFAULT_REL_END, REL_MIN, REL_MAX = 0.1, 0.9, 0, 1 DEFAULT_ABS_START, DEFAULT_ABS_END, ABS_MIN, ABS_MAX = 1, 10, 0, 10000 DEFAULT_N_TOKEN = 100 + DEFAULT_POS_TAGS = "NOUN,VERB" def __init__(self, parent=None, **kwargs): super().__init__(parent, **kwargs) @@ -530,6 +533,7 @@ def __init__(self, parent=None, **kwargs): self.__abs_freq_st = self.DEFAULT_ABS_START self.__abs_freq_en = self.DEFAULT_ABS_END self.__n_token = self.DEFAULT_N_TOKEN + self.__pos_tag = self.DEFAULT_POS_TAGS self.__invalidated = False self.__combo = ComboBox( @@ -574,6 +578,10 @@ def __init__(self, parent=None, **kwargs): self.__spin_n.editingFinished.connect(self.__spin_n_edited) self.__spin_n.valueChanged.connect(self.changed) + validator = PosTagFilter.validate_tags + self.__pos_edit = ValidatedLineEdit(self.__pos_tag, validator) + self.__pos_edit.editingFinished.connect(self.__pos_edit_finished) + self.layout().addWidget(self.__combo, self.Stopwords, 1) self.layout().addWidget(self.__sw_loader.file_combo, self.Stopwords, 2, 1, 2) @@ -595,6 +603,7 @@ def __init__(self, parent=None, **kwargs): title = self.layout().itemAtPosition(self.DummyDocFreq, 0).widget() title.hide() self.layout().addWidget(self.__spin_n, self.MostFreq, 1) + self.layout().addWidget(self.__pos_edit, self.PosTag, 1, 1, 5) self.layout().setColumnStretch(3, 1) def __sw_loader_activated(self): @@ -626,6 +635,13 @@ def __edit_finished(self): if self.Regexp in self.methods: self.edited.emit() + def __pos_edit_finished(self): + tags = self.__pos_edit.text() + if self.__pos_tag != tags: + self.__set_tags(tags) + if self.PosTag in self.methods: + self.edited.emit() + def __freq_group_clicked(self): i = self.__freq_group.checkedId() if self.__freq_type != i: @@ -666,6 +682,7 @@ def setParameters(self, params: Dict): params.get("abs_end", self.DEFAULT_ABS_END) ) self.__set_n_tokens(params.get("n_tokens", self.DEFAULT_N_TOKEN)) + self.__set_tags(params.get("pos_tags", 
self.DEFAULT_POS_TAGS)) self.__invalidated = False def __set_language(self, language: str): @@ -694,6 +711,12 @@ def __set_pattern(self, pattern: str): self.__edit.setText(pattern) self.changed.emit() + def __set_tags(self, tags: str): + if self.__pos_tag != tags: + self.__pos_tag = tags + self.__pos_edit.setText(tags) + self.changed.emit() + def __set_freq_type(self, freq_type: int): if self.__freq_type != freq_type: self.__freq_type = freq_type @@ -750,6 +773,7 @@ def parameters(self) -> Dict: "abs_start": self.__abs_freq_st, "abs_end": self.__abs_freq_en, "n_tokens": self.__n_token, + "pos_tags": self.__pos_tag, "invalidated": self.__invalidated}) return params @@ -782,6 +806,9 @@ def map_none(s): if FilteringModule.MostFreq in methods: n = params.get("n_tokens", FilteringModule.DEFAULT_N_TOKEN) filters.append(MostFrequentTokensFilter(keep_n=n)) + if FilteringModule.PosTag in methods: + tags = params.get("pos_tags", FilteringModule.DEFAULT_POS_TAGS) + filters.append(PosTagFilter(tags=tags)) return filters def __repr__(self): @@ -801,6 +828,8 @@ def __repr__(self): append = f"[{self.__abs_freq_st}, {self.__abs_freq_en}]" elif method == self.MostFreq: append = f"{self.__n_token}" + elif method == self.PosTag: + append = f"{self.__pos_tag}" texts.append(f"{self.Methods[method].name} ({append})") return ", ".join(texts) diff --git a/orangecontrib/text/widgets/tests/test_owpreprocess.py b/orangecontrib/text/widgets/tests/test_owpreprocess.py index a8248a919..c330d585e 100644 --- a/orangecontrib/text/widgets/tests/test_owpreprocess.py +++ b/orangecontrib/text/widgets/tests/test_owpreprocess.py @@ -485,7 +485,8 @@ def test_parameters(self): "freq_type": 0, "rel_start": 0.1, "rel_end": 0.9, "abs_start": 1, "abs_end": 10, - "n_tokens": 100, "invalidated": False} + "n_tokens": 100, "pos_tags": "NOUN,VERB", + "invalidated": False} self.assertDictEqual(self.editor.parameters(), params) def test_set_parameters(self): @@ -499,7 +500,8 @@ def test_set_parameters(self): 
"freq_type": 1, "rel_start": 0.2, "rel_end": 0.7, "abs_start": 2, "abs_end": 15, - "n_tokens": 10, "invalidated": False} + "n_tokens": 10, "pos_tags": "JJ", + "invalidated": False} self.editor.setParameters(params) self.assertDictEqual(self.editor.parameters(), params)