Convert the MWT training to use a PyTorch DataLoader with shuffling
In theory this should also provide some CPU/GPU parallelism at test time, although we haven't done anything to ensure it is using multiprocessing.

Fix the max_steps count by counting batches, not samples
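
As background, the general pattern adopted here is a map-style dataset (anything with __len__ and __getitem__) handed to torch.utils.data.DataLoader together with a custom collate_fn that pads each batch. The sketch below is illustrative only; ToyDataset and collate are made-up names, not the stanza classes changed in this commit:

import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

PAD_ID = 0

class ToyDataset:
    """Map-style dataset: returns one unpadded sample per index."""
    def __init__(self, samples):
        self.samples = samples                     # list of lists of token ids

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return torch.tensor(self.samples[idx])

def collate(batch):
    # pad the variable-length samples of one batch and build a padding mask
    padded = pad_sequence(batch, batch_first=True, padding_value=PAD_ID)
    mask = padded.eq(PAD_ID)
    return padded, mask

loader = DataLoader(ToyDataset([[5, 3], [7, 2, 9, 4], [8]]),
                    batch_size=2, shuffle=True, collate_fn=collate)
for padded, mask in loader:
    print(padded.shape, mask.shape)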
AngledLuffa committed Nov 29, 2024
1 parent fd74bf8 commit 6c47374
Showing 4 changed files with 58 additions and 37 deletions.
81 changes: 50 additions & 31 deletions stanza/models/mwt/data.py
@@ -1,10 +1,12 @@
import random
import numpy as np
import os
from collections import Counter
from collections import Counter, namedtuple
import logging

import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader as DL

import stanza.models.common.seq2seq_constant as constant
from stanza.models.common.data import map_to_ids, get_long_tensor, get_float_tensor, sort_all
@@ -14,6 +16,9 @@

logger = logging.getLogger('stanza')

DataSample = namedtuple("DataSample", "src tgt_in tgt_out orig_text")
DataBatch = namedtuple("DataBatch", "src src_mask tgt_in tgt_out orig_text orig_idx")

# enforce that the MWT splitter knows about a couple different alternate apostrophes
# including covering some potential " typos
# setting the augmentation to a very low value should be enough to teach it
@@ -22,7 +27,6 @@
# 0x22, 0x27, 0x02BC, 0x02CA, 0x055A, 0x07F4, 0x2019, 0xFF07
APOS = ('"', "'", 'ʼ', 'ˊ', '՚', 'ߴ', '’', '＇')

# TODO: can wrap this in a Pytorch DataLoader, such as what was done for POS
class DataLoader:
def __init__(self, doc, batch_size, args, vocab=None, evaluation=False, expand_unk_vocab=False):
self.batch_size = batch_size
@@ -56,12 +60,9 @@ def __init__(self, doc, batch_size, args, vocab=None, evaluation=False, expand_u
indices = list(range(len(data)))
random.shuffle(indices)
data = [data[i] for i in indices]
self.num_examples = len(data)

# chunk into batches
data = [data[i:i+batch_size] for i in range(0, len(data), batch_size)]
self.data = data
logger.debug("{} batches created.".format(len(data)))
self.num_examples = len(data)

def init_vocab(self, data):
assert self.evaluation == False # for eval vocab must exist
@@ -77,17 +78,14 @@ def maybe_augment_apos(self, datum):
break
return datum


def process(self, data):
processed = []
for d in data:
if not self.evaluation and self.augment_apos > 0:
d = self.maybe_augment_apos(d)
src = list(d[0])
src = [constant.SOS] + src + [constant.EOS]
tgt_in, tgt_out = self.prepare_target(self.vocab, d)
src = self.vocab.map(src)
processed += [[src, tgt_in, tgt_out, d[0]]]
def process(self, sample):
if not self.evaluation and self.augment_apos > 0:
sample = self.maybe_augment_apos(sample)
src = list(sample[0])
src = [constant.SOS] + src + [constant.EOS]
tgt_in, tgt_out = self.prepare_target(self.vocab, sample)
src = self.vocab.map(src)
processed = [src, tgt_in, tgt_out, sample[0]]
return processed

def prepare_target(self, vocab, datum):
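
For orientation, process() now handles a single expansion at a time: the surface form is split into characters, wrapped in SOS/EOS markers, and mapped to ids. A rough stand-alone illustration with a made-up character vocabulary (not stanza's vocab class):

SOS, EOS, UNK_ID = "<s>", "</s>", 1
fake_vocab = {"<pad>": 0, "<unk>": 1, "<s>": 2, "</s>": 3,
              "d": 4, "e": 5, "l": 6, "a": 7}

surface = "della"                              # the unexpanded MWT surface form
src = [SOS] + list(surface) + [EOS]            # character sequence with boundary symbols
src_ids = [fake_vocab.get(c, UNK_ID) for c in src]
print(src_ids)                                 # [2, 4, 5, 6, 6, 7, 3]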
@@ -108,31 +106,52 @@ def __getitem__(self, key):
raise TypeError
if key < 0 or key >= len(self.data):
raise IndexError
batch = self.data[key]
batch = self.process(batch)
batch_size = len(batch)
batch = list(zip(*batch))
assert len(batch) == 4
sample = self.data[key]
sample = self.process(sample)
assert len(sample) == 4

src = torch.tensor(sample[0])
tgt_in = torch.tensor(sample[1])
tgt_out = torch.tensor(sample[2])
orig_text = sample[3]
result = DataSample(src, tgt_in, tgt_out, orig_text), key
return result

# sort all fields by lens for easy RNN operations
lens = [len(x) for x in batch[0]]
batch, orig_idx = sort_all(batch, lens)
@staticmethod
def __collate_fn(data):
(data, idx) = zip(*data)
(src, tgt_in, tgt_out, orig_text) = zip(*data)

# collate_fn is given a list of length batch size
batch_size = len(data)

lens = [len(x) for x in tgt_in]
(src, tgt_in, tgt_out, orig_text), orig_idx = sort_all((src, tgt_in, tgt_out, orig_text), lens)
lens = [len(x) for x in tgt_in]

# convert to tensors
src = batch[0]
src = get_long_tensor(src, batch_size)
src = pad_sequence(src, True, constant.PAD_ID)
src_mask = torch.eq(src, constant.PAD_ID)
tgt_in = get_long_tensor(batch[1], batch_size)
tgt_out = get_long_tensor(batch[2], batch_size)
orig_text = batch[3]
tgt_in = pad_sequence(tgt_in, True, constant.PAD_ID)
tgt_out = pad_sequence(tgt_out, True, constant.PAD_ID)
assert tgt_in.size(1) == tgt_out.size(1), \
"Target input and output sequence sizes do not match."
return (src, src_mask, tgt_in, tgt_out, orig_text, orig_idx)
return DataBatch(src, src_mask, tgt_in, tgt_out, orig_text, orig_idx)

def __iter__(self):
for i in range(self.__len__()):
yield self.__getitem__(i)

def to_loader(self):
"""Converts self to a DataLoader """

batch_size = self.batch_size
shuffle = not self.evaluation
return DL(self,
collate_fn=self.__collate_fn,
batch_size=batch_size,
shuffle=shuffle)

def load_doc(self, doc, evaluation=False):
data = doc.get_mwt_expansions(evaluation)
if evaluation: data = [[e] for e in data]
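One detail worth calling out in the new __collate_fn: the batch is sorted by target length before padding (via stanza's sort_all), and orig_idx is carried along so predictions can later be put back in the original order. A plain-Python illustration of that sort/unsort round trip (not the stanza helper itself):

samples = ["a b c", "x", "p q"]                     # stand-ins for target sequences
lens = [len(s.split()) for s in samples]            # [3, 1, 2]

# sort by descending length, remembering where each item came from
order = sorted(range(len(samples)), key=lambda i: -lens[i])   # [0, 2, 1]
sorted_samples = [samples[i] for i in order]
orig_idx = order

# pretend the model produced one prediction per sorted sample
preds_sorted = [s.upper() for s in sorted_samples]

# undo the sort so predictions line up with the original input order
preds = [None] * len(samples)
for sorted_pos, original_pos in enumerate(orig_idx):
    preds[original_pos] = preds_sorted[sorted_pos]
print(preds)                                         # ['A B C', 'X', 'P Q']
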
10 changes: 6 additions & 4 deletions stanza/models/mwt_expander.py
@@ -16,6 +16,7 @@
from datetime import datetime
import argparse
import logging
import math
import numpy as np
import random
import torch
@@ -184,7 +185,8 @@ def train(args):
# train a seq2seq model
logger.info("Training seq2seq-based MWT expander...")
global_step = 0
max_steps = len(train_batch) * args['num_epoch']
steps_per_epoch = math.ceil(len(train_batch) / args['batch_size'])
max_steps = steps_per_epoch * args['num_epoch']
dev_score_history = []
best_dev_preds = []
current_lr = args['lr']
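
To make the max_steps fix concrete: global_step is incremented once per batch, so the cap has to be counted in batches as well. With len(train_batch) now counting individual expansions rather than pre-chunked batches, multiplying it directly by num_epoch would overshoot. A toy calculation (numbers invented for illustration):

import math

num_expansions = 1000        # hypothetical number of MWT training examples
batch_size = 32
num_epoch = 30

per_sample_count = num_expansions * num_epoch               # 30000, far too many
steps_per_epoch = math.ceil(num_expansions / batch_size)    # 32 batches per epoch
max_steps = steps_per_epoch * num_epoch                     # 960, matches global_step
print(per_sample_count, steps_per_epoch, max_steps)
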
@@ -201,7 +203,7 @@
# start training
for epoch in range(1, args['num_epoch']+1):
train_loss = 0
for i, batch in enumerate(train_batch):
for i, batch in enumerate(train_batch.to_loader()):
start_time = time.time()
global_step += 1
loss = trainer.update(batch, eval=False) # update step
@@ -218,7 +220,7 @@
# eval on dev
logger.info("Evaluating on dev set...")
dev_preds = []
for i, batch in enumerate(dev_batch):
for i, batch in enumerate(dev_batch.to_loader()):
preds = trainer.predict(batch)
dev_preds += preds
if args.get('ensemble_dict', False) and args.get('ensemble_early_stop', False):
@@ -296,7 +298,7 @@ def evaluate(args):
else:
logger.info("Running the seq2seq model...")
preds = []
for i, b in enumerate(batch):
for i, b in enumerate(batch.to_loader()):
preds += trainer.predict(b)

if loaded_args.get('ensemble_dict', False):
2 changes: 1 addition & 1 deletion stanza/pipeline/mwt_processor.py
@@ -37,7 +37,7 @@ def process(self, document):
else:
with torch.no_grad():
preds = []
for i, b in enumerate(batch):
for i, b in enumerate(batch.to_loader()):
preds += self.trainer.predict(b, never_decode_unk=True, vocab=batch.vocab)

if self.config.get('ensemble_dict', False):
2 changes: 1 addition & 1 deletion stanza/tests/mwt/test_character_classifier.py
@@ -81,7 +81,7 @@ def test_train(tmp_path):
doc = CoNLL.conll2doc(input_str=ENG_DEV)
dataloader = DataLoader(doc, 10, model.args, vocab=model.vocab, evaluation=True, expand_unk_vocab=True)
preds = []
for i, batch in enumerate(dataloader):
for i, batch in enumerate(dataloader.to_loader()):
assert i == 0 # there should only be one batch
preds += model.predict(batch, never_decode_unk=True, vocab=dataloader.vocab)
assert len(preds) == 1
