Skip to content

Commit

Permalink
Separate out transitions only loss from the total loss reporting
Browse files Browse the repository at this point in the history
  • Loading branch information
AngledLuffa committed Dec 17, 2024
1 parent 5bf876a commit 276f15c
Showing 1 changed file with 19 additions and 15 deletions.
34 changes: 19 additions & 15 deletions stanza/models/constituency/parser_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,17 @@

TrainItem = namedtuple("TrainItem", ['tree', 'gold_sequence', 'preterminals'])

class EpochStats(namedtuple("EpochStats", ['epoch_loss', 'contrastive_loss', 'transitions_correct', 'transitions_incorrect', 'repairs_used', 'fake_transitions_used', 'nans'])):
class EpochStats(namedtuple("EpochStats", ['transition_loss', 'contrastive_loss', 'total_loss', 'transitions_correct', 'transitions_incorrect', 'repairs_used', 'fake_transitions_used', 'nans'])):
def __add__(self, other):
transitions_correct = self.transitions_correct + other.transitions_correct
transitions_incorrect = self.transitions_incorrect + other.transitions_incorrect
repairs_used = self.repairs_used + other.repairs_used
fake_transitions_used = self.fake_transitions_used + other.fake_transitions_used
epoch_loss = self.epoch_loss + other.epoch_loss
transition_loss = self.transition_loss + other.transition_loss
contrastive_loss = self.contrastive_loss + other.contrastive_loss
total_loss = self.total_loss + other.total_loss
nans = self.nans + other.nans
return EpochStats(epoch_loss, contrastive_loss, transitions_correct, transitions_incorrect, repairs_used, fake_transitions_used, nans)
return EpochStats(transition_loss, contrastive_loss, total_loss, transitions_correct, transitions_incorrect, repairs_used, fake_transitions_used, nans)

def evaluate(args, model_file, retag_pipeline):
"""
Expand Down Expand Up @@ -435,13 +436,14 @@ def iterate_training(args, trainer, train_trees, train_sequences, transitions, d
tlogger.warning("Had to ignore %d batches with NaN", epoch_stats.nans)
stats_log_lines = [
"Epoch %d finished" % trainer.epochs_trained,
"Transitions correct: %s" % epoch_stats.transitions_correct,
"Transitions correct: %s" % epoch_stats.transitions_correct,
"Transitions incorrect: %s" % epoch_stats.transitions_incorrect,
"Total loss for epoch: %.5f" % epoch_stats.epoch_loss,
"Transition loss for epoch: %.5f" % epoch_stats.transition_loss,
]
if args['contrastive_learning_rate'] > 0.0:
stats_log_lines.extend([
"Contrastive loss for epoch: %.5f" % epoch_stats.contrastive_loss
"Contrastive loss for epoch: %.5f" % epoch_stats.contrastive_loss,
"Total loss for epoch: %.5f" % epoch_stats.total_loss,
])
stats_log_lines.extend([
"Dev score (%5d): %8f" % (trainer.epochs_trained, f1),
Expand All @@ -456,7 +458,7 @@ def iterate_training(args, trainer, train_trees, train_sequences, transitions, d
tlogger.info("Updating learning rate from %f to %f", old_lr, new_lr)

if args['wandb']:
wandb.log({'epoch_loss': epoch_stats.epoch_loss, 'dev_score': f1}, step=trainer.epochs_trained)
wandb.log({'total_loss': epoch_stats.total_loss, 'dev_score': f1}, step=trainer.epochs_trained)
if args['wandb_norm_regex']:
watch_regex = re.compile(args['wandb_norm_regex'])
for n, p in trainer.model.named_parameters():
Expand Down Expand Up @@ -546,7 +548,7 @@ def train_model_one_epoch(epoch, trainer, transition_tensors, process_outputs, m

optimizer = trainer.optimizer

epoch_stats = EpochStats(0.0, 0.0, Counter(), Counter(), Counter(), 0, 0)
epoch_stats = EpochStats(0.0, 0.0, 0.0, Counter(), Counter(), Counter(), 0, 0)

for batch_idx, interval_start in enumerate(tqdm(interval_starts, postfix="Epoch %d" % epoch)):
batch = epoch_data[interval_start:interval_start+args['train_batch_size']]
Expand Down Expand Up @@ -694,9 +696,9 @@ def train_model_one_batch(epoch, batch_idx, model, training_batch, transition_te
answers = torch.cat(all_answers)

errors = process_outputs(errors)
tree_loss = model_loss_function(errors, answers)
tree_loss += contrastive_loss
tree_loss.backward()
transition_loss = model_loss_function(errors, answers)
total_loss = transition_loss + contrastive_loss
total_loss.backward()
if args['watch_regex']:
matched = False
tlogger.info("Watching %s ... epoch %d batch %d", args['watch_regex'], epoch, batch_idx)
Expand All @@ -712,17 +714,19 @@ def train_model_one_batch(epoch, batch_idx, model, training_batch, transition_te
tlogger.info(" %s norm: %f grad not required", n, torch.linalg.norm(p))
if not matched:
tlogger.info(" (none found!)")
if torch.any(torch.isnan(tree_loss)):
batch_loss = 0.0
if torch.any(torch.isnan(total_loss)):
total_loss = 0.0
transition_loss = 0.0
contrastive_loss = 0.0
nans = 1
else:
batch_loss = tree_loss.item()
total_loss = total_loss.item()
transition_loss = transition_loss.item()
if not isinstance(contrastive_loss, float):
contrastive_loss = contrastive_loss.item()
nans = 0

return EpochStats(batch_loss, contrastive_loss, transitions_correct, transitions_incorrect, repairs_used, fake_transitions_used, nans)
return EpochStats(transition_loss, contrastive_loss, total_loss, transitions_correct, transitions_incorrect, repairs_used, fake_transitions_used, nans)

def run_dev_set(model, retagged_trees, original_trees, args, evaluator=None):
"""
Expand Down

0 comments on commit 276f15c

Please sign in to comment.