diff --git a/stanza/models/constituency/parser_training.py b/stanza/models/constituency/parser_training.py
index c3006a3b3..339a4c1c6 100644
--- a/stanza/models/constituency/parser_training.py
+++ b/stanza/models/constituency/parser_training.py
@@ -34,16 +34,17 @@ TrainItem = namedtuple("TrainItem", ['tree', 'gold_sequence', 'preterminals'])
 
 
-class EpochStats(namedtuple("EpochStats", ['epoch_loss', 'contrastive_loss', 'transitions_correct', 'transitions_incorrect', 'repairs_used', 'fake_transitions_used', 'nans'])):
+class EpochStats(namedtuple("EpochStats", ['transition_loss', 'contrastive_loss', 'total_loss', 'transitions_correct', 'transitions_incorrect', 'repairs_used', 'fake_transitions_used', 'nans'])):
     def __add__(self, other):
         transitions_correct = self.transitions_correct + other.transitions_correct
         transitions_incorrect = self.transitions_incorrect + other.transitions_incorrect
         repairs_used = self.repairs_used + other.repairs_used
         fake_transitions_used = self.fake_transitions_used + other.fake_transitions_used
-        epoch_loss = self.epoch_loss + other.epoch_loss
+        transition_loss = self.transition_loss + other.transition_loss
         contrastive_loss = self.contrastive_loss + other.contrastive_loss
+        total_loss = self.total_loss + other.total_loss
         nans = self.nans + other.nans
-        return EpochStats(epoch_loss, contrastive_loss, transitions_correct, transitions_incorrect, repairs_used, fake_transitions_used, nans)
+        return EpochStats(transition_loss, contrastive_loss, total_loss, transitions_correct, transitions_incorrect, repairs_used, fake_transitions_used, nans)
 
 
 def evaluate(args, model_file, retag_pipeline):
     """
@@ -435,13 +436,14 @@ def iterate_training(args, trainer, train_trees, train_sequences, transitions, d
             tlogger.warning("Had to ignore %d batches with NaN", epoch_stats.nans)
         stats_log_lines = [
             "Epoch %d finished" % trainer.epochs_trained,
-            "Transitions correct: %s" % epoch_stats.transitions_correct,
+            "Transitions correct: %s" % epoch_stats.transitions_correct,
             "Transitions incorrect: %s" % epoch_stats.transitions_incorrect,
-            "Total loss for epoch: %.5f" % epoch_stats.epoch_loss,
+            "Transition loss for epoch: %.5f" % epoch_stats.transition_loss,
         ]
         if args['contrastive_learning_rate'] > 0.0:
             stats_log_lines.extend([
-                "Contrastive loss for epoch: %.5f" % epoch_stats.contrastive_loss
+                "Contrastive loss for epoch: %.5f" % epoch_stats.contrastive_loss,
+                "Total loss for epoch: %.5f" % epoch_stats.total_loss,
             ])
         stats_log_lines.extend([
             "Dev score (%5d): %8f" % (trainer.epochs_trained, f1),
@@ -456,7 +458,7 @@ def iterate_training(args, trainer, train_trees, train_sequences, transitions, d
             tlogger.info("Updating learning rate from %f to %f", old_lr, new_lr)
 
         if args['wandb']:
-            wandb.log({'epoch_loss': epoch_stats.epoch_loss, 'dev_score': f1}, step=trainer.epochs_trained)
+            wandb.log({'total_loss': epoch_stats.total_loss, 'dev_score': f1}, step=trainer.epochs_trained)
         if args['wandb_norm_regex']:
             watch_regex = re.compile(args['wandb_norm_regex'])
             for n, p in trainer.model.named_parameters():
@@ -546,7 +548,7 @@ def train_model_one_epoch(epoch, trainer, transition_tensors, process_outputs, m
 
     optimizer = trainer.optimizer
 
-    epoch_stats = EpochStats(0.0, 0.0, Counter(), Counter(), Counter(), 0, 0)
+    epoch_stats = EpochStats(0.0, 0.0, 0.0, Counter(), Counter(), Counter(), 0, 0)
 
     for batch_idx, interval_start in enumerate(tqdm(interval_starts, postfix="Epoch %d" % epoch)):
         batch = epoch_data[interval_start:interval_start+args['train_batch_size']]
@@ -694,9 +696,9 @@ def train_model_one_batch(epoch, batch_idx, model, training_batch, transition_te
     answers = torch.cat(all_answers)
 
     errors = process_outputs(errors)
-    tree_loss = model_loss_function(errors, answers)
-    tree_loss += contrastive_loss
-    tree_loss.backward()
+    transition_loss = model_loss_function(errors, answers)
+    total_loss = transition_loss + contrastive_loss
+    total_loss.backward()
     if args['watch_regex']:
         matched = False
         tlogger.info("Watching %s ... epoch %d batch %d", args['watch_regex'], epoch, batch_idx)
@@ -712,17 +714,19 @@ def train_model_one_batch(epoch, batch_idx, model, training_batch, transition_te
                 tlogger.info("  %s norm: %f grad not required", n, torch.linalg.norm(p))
         if not matched:
             tlogger.info("  (none found!)")
-    if torch.any(torch.isnan(tree_loss)):
-        batch_loss = 0.0
+    if torch.any(torch.isnan(total_loss)):
+        total_loss = 0.0
+        transition_loss = 0.0
         contrastive_loss = 0.0
         nans = 1
     else:
-        batch_loss = tree_loss.item()
+        total_loss = total_loss.item()
+        transition_loss = transition_loss.item()
         if not isinstance(contrastive_loss, float):
             contrastive_loss = contrastive_loss.item()
         nans = 0
-    return EpochStats(batch_loss, contrastive_loss, transitions_correct, transitions_incorrect, repairs_used, fake_transitions_used, nans)
+    return EpochStats(transition_loss, contrastive_loss, total_loss, transitions_correct, transitions_incorrect, repairs_used, fake_transitions_used, nans)
 
 
 def run_dev_set(model, retagged_trees, original_trees, args, evaluator=None):
     """
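
For context, a minimal runnable sketch of how the widened EpochStats accumulates over batches. The field names and the zero element are taken from the patch above; the zip-based __add__ is a compact equivalent of the patched field-by-field sums, and the per-batch numbers are invented for illustration:

from collections import Counter, namedtuple

FIELDS = ['transition_loss', 'contrastive_loss', 'total_loss', 'transitions_correct',
          'transitions_incorrect', 'repairs_used', 'fake_transitions_used', 'nans']

class EpochStats(namedtuple("EpochStats", FIELDS)):
    def __add__(self, other):
        # Every field supports +: floats and ints add numerically, Counters merge counts
        return EpochStats(*(a + b for a, b in zip(self, other)))

# Zero element, matching train_model_one_epoch after the patch
epoch_stats = EpochStats(0.0, 0.0, 0.0, Counter(), Counter(), Counter(), 0, 0)

# Two invented per-batch results, shaped like train_model_one_batch's return value
batch1 = EpochStats(1.5, 0.25, 1.75, Counter({'Shift': 40}), Counter({'Shift': 2}), Counter(), 0, 0)
batch2 = EpochStats(1.0, 0.25, 1.25, Counter({'Shift': 38}), Counter({'Shift': 4}), Counter(), 0, 0)

epoch_stats = epoch_stats + batch1 + batch2
print(epoch_stats.transition_loss)      # 2.5
print(epoch_stats.total_loss)           # 3.0
print(epoch_stats.transitions_correct)  # Counter({'Shift': 78})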