diff --git a/docs/code/data_aug.rst b/docs/code/data_aug.rst index ecdbb5294..9bd925b32 100644 --- a/docs/code/data_aug.rst +++ b/docs/code/data_aug.rst @@ -159,6 +159,11 @@ Data Augmentation Ops .. autoclass:: forte.processors.data_augment.algorithms.eda_ops.RandomDeletionDataAugmentOp :members: +:hidden:`AbbreviationReplacementOp` +------------------------------------------ +.. autoclass:: forte.processors.data_augment.algorithms.abbreviation_replacement_op.AbbreviationReplacementOp + :members: + Data Augmentation Models ======================================== diff --git a/forte/processors/data_augment/algorithms/abbreviation_replacement_op.py b/forte/processors/data_augment/algorithms/abbreviation_replacement_op.py new file mode 100644 index 000000000..79bed2b9c --- /dev/null +++ b/forte/processors/data_augment/algorithms/abbreviation_replacement_op.py @@ -0,0 +1,106 @@ +# Copyright 2022 The Forte Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import random
import json
from typing import Tuple, Dict, Any

import requests
from forte.data.ontology import Annotation
from forte.processors.data_augment.algorithms.single_annotation_op import (
    SingleAnnotationAugmentOp,
)
from forte.common.configuration import Config

__all__ = [
    "AbbreviationReplacementOp",
]


class AbbreviationReplacementOp(SingleAnnotationAugmentOp):
    r"""
    This class is a replacement op utilizing a pre-defined
    abbreviation dictionary to replace a word or phrase
    with an abbreviation. The abbreviation dictionary can
    be user-defined; a default dictionary is also provided.
    `prob` indicates the probability of replacement.

    The dictionary is loaded once, at construction time, from
    `dict_path`: it is first fetched as a URL and, if that request
    fails, opened as a local JSON file.
    """

    # Upper bound (in seconds) on the dictionary download, so that an
    # unresponsive server cannot block construction of the op forever.
    _REQUEST_TIMEOUT = 30

    def __init__(self, configs: Config):
        r"""
        Load the abbreviation dictionary referenced by `dict_path`.

        Args:
            configs: The configuration of this op. It must provide
                `dict_path`, a URL or a local path of a json file.

        Raises:
            OSError: If `dict_path` can neither be downloaded nor
                opened as a local file.
        """
        super().__init__(configs)

        dict_path = configs["dict_path"]

        try:
            response = requests.get(dict_path, timeout=self._REQUEST_TIMEOUT)
            # Treat HTTP error statuses (404, 500, ...) as a failed
            # download instead of trying to parse the error page.
            response.raise_for_status()
            self.data = response.json()
        except requests.exceptions.RequestException:
            # Not a reachable URL (a plain file path is rejected by
            # `requests` with a `RequestException` subclass as well), so
            # fall back to reading it from the local file system.
            with open(dict_path, encoding="utf8") as json_file:
                self.data = json.load(json_file)

    def single_annotation_augment(
        self, input_anno: Annotation
    ) -> Tuple[bool, str]:
        r"""
        This function replaces a phrase using the abbreviation dictionary,
        with `prob` as the probability of replacement.
        If the input phrase does not have a corresponding entry in the
        dictionary, no replacement will happen, return False.

        Args:
            input_anno: The input annotation, could be a word or phrase.

        Returns:
            A tuple, where the first element is a boolean value indicating
            whether the replacement happens, and the second element is the
            replaced string (the original text when nothing is replaced).

        """
        original_text = input_anno.text

        # `random.random()` draws from [0, 1), so `>=` makes the chance of
        # replacement exactly `prob`: never for 0, always for 1.
        if random.random() >= self.configs.prob:
            return False, original_text

        if original_text in self.data:
            result: str = self.data[original_text]
            return True, result

        # No abbreviation is known for this text.
        return False, original_text

    @classmethod
    def default_configs(cls) -> Dict[str, Any]:
        r"""
        Returns:
            A dictionary with the default config for this processor.
            Following are the keys for this dictionary:

            - augment_entry: Defaults to
              `"ft.onto.base_ontology.Phrase"`.

            - other_entry_policy: Defaults to `"auto_align"` for
              `"ft.onto.base_ontology.Phrase"`.

            - prob: The probability of replacement,
              should fall in [0, 1]. Default value is 0.5.

            - dict_path: the `url` or the path to the pre-defined
              abbreviation json file. The key is a word / phrase we want
              to replace. The value is an abbreviated word of the
              corresponding key. Default dictionary is from a web-scraped
              slang dictionary
              ("https://raw.githubusercontent.com/abbeyyyy/JsonFiles/main/abbreviate.json").

        """
        return {
            "augment_entry": "ft.onto.base_ontology.Phrase",
            "other_entry_policy": {
                "ft.onto.base_ontology.Phrase": "auto_align",
            },
            "dict_path": "https://raw.githubusercontent.com/abbeyyyy/"
            "JsonFiles/main/abbreviate.json",
            "prob": 0.5,
        }
import unittest
from forte.data.data_pack import DataPack
from ft.onto.base_ontology import Phrase
from forte.processors.data_augment.algorithms.abbreviation_replacement_op import (
    AbbreviationReplacementOp,
)


class TestAbbreviationReplacementOp(unittest.TestCase):
    """Checks `AbbreviationReplacementOp` on single-phrase data packs."""

    def setUp(self):
        # `prob` is 1.0 so the random gate in the op never skips a phrase.
        op_configs = {
            "dict_path": "https://raw.githubusercontent.com/abbeyyyy/"
            "JsonFiles/main/abbreviate.json",
            "prob": 1.0,
        }
        self.abre = AbbreviationReplacementOp(configs=op_configs)

    def _augmented_phrase_text(self, text, begin, end):
        """Augment a pack holding one `Phrase` over `text[begin:end]` and
        return the text of the first phrase in the augmented pack."""
        pack = DataPack()
        pack.set_text(text)
        phrase = Phrase(pack, begin, end)
        pack.add_entry(phrase)

        augmented_pack = self.abre.perform_augmentation(pack)
        phrases = list(augmented_pack.get("ft.onto.base_ontology.Phrase"))
        return phrases[0]

    def test_replace(self):
        sentence = "I will see you later!"

        # (begin, end, acceptable augmented texts), checked in order.
        scenarios = (
            # "see you later": expected to be abbreviated
            (7, len(sentence) - 1, ["syl8r", "cul83r", "cul8r"]),
            # Empty phrase
            (0, 0, [""]),
            # no abbreviation exist
            (2, 6, ["will"]),
        )

        for begin, end, acceptable in scenarios:
            augmented_phrase = self._augmented_phrase_text(
                sentence, begin, end
            )
            self.assertIn(augmented_phrase.text, acceptable)


if __name__ == "__main__":
    unittest.main()