Skip to content

Commit

Permalink
When training, take advantage of CorrectForm annotations on the words
Browse files Browse the repository at this point in the history
  • Loading branch information
AngledLuffa committed Dec 21, 2024
1 parent d693c2a commit dbdf429
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 2 deletions.
43 changes: 41 additions & 2 deletions stanza/models/lemma/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,13 +128,52 @@ def load_doc(doc, caseless, evaluation):
if evaluation:
data = doc.get([TEXT, UPOS, LEMMA])
else:
data = doc.get([TEXT, UPOS, LEMMA, HEAD, DEPREL], as_sentences=True)
data = doc.get([TEXT, UPOS, LEMMA, HEAD, DEPREL, MISC], as_sentences=True)
data = DataLoader.remove_goeswith(data)
data = DataLoader.extract_correct_forms(data)
data = DataLoader.resolve_none(data)
if caseless:
data = DataLoader.lowercase_data(data)
return data

@staticmethod
def extract_correct_forms(data):
"""
Here we go through the raw data and use the CorrectForm of words tagged with CorrectForm
In addition, if the incorrect form of the word is not present in the training data,
we keep the incorrect form for the lemmatizer to learn from.
This way, it can occasionally get things right in misspelled input text.
We do check for and eliminate words where the incorrect form is already known as the
lemma for a different word. For example, in the English datasets, there is a "busy"
which was meant to be "buys", and we don't want the model to learn to lemmatize "busy" to "buy"
"""
new_data = []
incorrect_forms = []
for word in data:
misc = word[-1]
if not misc:
new_data.append(word[:3])
continue
misc = misc.split("|")
for piece in misc:
if piece.startswith("CorrectForm="):
cf = piece.split("=", maxsplit=1)[1]
# treat the CorrectForm as the desired word
new_data.append((cf, word[1], word[2]))
# and save the broken one for later in case it wasn't used anywhere else
incorrect_forms.append((cf, word))
break
else:
# if no CorrectForm, just keep the word as normal
new_data.append(word[:3])
known_words = {x[0] for x in new_data}
for correct_form, word in incorrect_forms:
if word[0] not in known_words:
new_data.append(word[:3])
return new_data

@staticmethod
def remove_goeswith(data):
"""
Expand All @@ -154,7 +193,7 @@ def remove_goeswith(data):
if word[4] == 'goeswith':
remove_indices.add(word_idx)
remove_indices.add(word[3]-1)
filtered_data.extend([x[:3] for idx, x in enumerate(sentence) if idx not in remove_indices])
filtered_data.extend([x for idx, x in enumerate(sentence) if idx not in remove_indices])
return filtered_data

@staticmethod
Expand Down
23 changes: 23 additions & 0 deletions stanza/tests/lemma/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,16 @@
""".lstrip()

CORRECT_FORM_DATA = """
# sent_id = weblog-blogspot.com_healingiraq_20040409053012_ENG_20040409_053012-0019
# text = They are targetting ambulances
1 They they PRON PRP Case=Nom|Number=Plur|Person=3|PronType=Prs 3 nsubj 3:nsubj _
2 are be AUX VBP Mood=Ind|Number=Plur|Person=3|Tense=Pres|VerbForm=Fin 3 aux 3:aux _
3 targetting target VERB VBG Tense=Pres|Typo=Yes|VerbForm=Part 0 root 0:root CorrectForm=targeting
4 ambulances ambulance NOUN NNS Number=Plur 3 obj 3:obj SpaceAfter=No
"""


def test_load_document():
train_doc = CoNLL.conll2doc(input_str=TRAIN_DATA)
data = DataLoader.load_doc(train_doc, caseless=False, evaluation=True)
Expand All @@ -81,3 +91,16 @@ def test_load_goeswith():
assert len(data) == 33 # will be the same as in test_load_document, but with the trailing 3 GOESWITH removed
assert all(len(x) == 3 for x in data)

def test_correct_form():
raw_data = TRAIN_DATA + CORRECT_FORM_DATA
train_doc = CoNLL.conll2doc(input_str=raw_data)
data = DataLoader.load_doc(train_doc, caseless=False, evaluation=True)
assert len(data) == 37
# the 'targeting' correction should not be applied if evaluation=True
# when evaluation=False, then the CorrectForms will be applied
assert not any(x[0] == 'targeting' for x in data)

data = DataLoader.load_doc(train_doc, caseless=False, evaluation=False)
assert len(data) == 38 # the same, but with an extra row so the model learns both 'targetting' and 'targeting'
assert any(x[0] == 'targeting' for x in data)
assert any(x[0] == 'targetting' for x in data)

0 comments on commit dbdf429

Please sign in to comment.