From d37254119830a7573cdaadb0f98b1cda9227d8f3 Mon Sep 17 00:00:00 2001 From: John Bauer Date: Wed, 4 Dec 2024 02:42:11 -0800 Subject: [PATCH] Add a prototype of orthogonal loss for neighboring subtrees. Need to add crossing subtrees and possibly more members of the same subtree as orthogonal losses. Uses MSELoss over the dot product to add the orthogonal loss --- stanza/models/constituency/parser_training.py | 75 ++++++++++++++++--- stanza/models/constituency_parser.py | 3 + 2 files changed, 69 insertions(+), 9 deletions(-) diff --git a/stanza/models/constituency/parser_training.py b/stanza/models/constituency/parser_training.py index b133045a1..0d8171c41 100644 --- a/stanza/models/constituency/parser_training.py +++ b/stanza/models/constituency/parser_training.py @@ -34,15 +34,16 @@ TrainItem = namedtuple("TrainItem", ['tree', 'gold_sequence', 'preterminals']) -class EpochStats(namedtuple("EpochStats", ['epoch_loss', 'transitions_correct', 'transitions_incorrect', 'repairs_used', 'fake_transitions_used', 'nans'])): +class EpochStats(namedtuple("EpochStats", ['epoch_loss', 'orthogonal_loss', 'transitions_correct', 'transitions_incorrect', 'repairs_used', 'fake_transitions_used', 'nans'])): def __add__(self, other): transitions_correct = self.transitions_correct + other.transitions_correct transitions_incorrect = self.transitions_incorrect + other.transitions_incorrect repairs_used = self.repairs_used + other.repairs_used fake_transitions_used = self.fake_transitions_used + other.fake_transitions_used epoch_loss = self.epoch_loss + other.epoch_loss + orthogonal_loss = self.orthogonal_loss + other.orthogonal_loss nans = self.nans + other.nans - return EpochStats(epoch_loss, transitions_correct, transitions_incorrect, repairs_used, fake_transitions_used, nans) + return EpochStats(epoch_loss, orthogonal_loss, transitions_correct, transitions_incorrect, repairs_used, fake_transitions_used, nans) def evaluate(args, model_file, retag_pipeline): """ @@ -335,6 +336,8 @@ def iterate_training(args, trainer, train_trees, train_sequences, transitions, d add the errors to the list of things to backprop advance the parsing state for each of the trees """ + device = trainer.device + # Somewhat unusual, but possibly related to the extreme variability in length of trees # Various experiments generally show about 0.5 F1 loss on various # datasets when using 'mean' instead of 'sum' for reduction @@ -360,6 +363,19 @@ def iterate_training(args, trainer, train_trees, train_sequences, transitions, d device = trainer.device model_loss_function.to(device) + + if args['orthogonal_learning_rate'] > 0: + # To measure which of the neighboring subtrees are orthogonal, + # we take the dot product of their embeddings and then try to + # force that to 0 using MSELoss + # That way, either positive or negative dot product + # (which represents some form of correlation) + # is penalized + orthogonal_loss_function = nn.MSELoss(reduction='sum') + orthogonal_loss_function.to(device) + else: + orthogonal_loss_function = None + transition_tensors = {x: torch.tensor(y, requires_grad=False, device=device).unsqueeze(0) for (y, x) in enumerate(trainer.transitions)} trainer.train() @@ -409,7 +425,7 @@ def iterate_training(args, trainer, train_trees, train_sequences, transitions, d epoch_data = epoch_data + epoch_silver_data epoch_data.sort(key=lambda x: len(x[1])) - epoch_stats = train_model_one_epoch(trainer.epochs_trained, trainer, transition_tensors, process_outputs, model_loss_function, epoch_data, oracle, args) + epoch_stats = train_model_one_epoch(trainer.epochs_trained, trainer, transition_tensors, process_outputs, model_loss_function, orthogonal_loss_function, epoch_data, oracle, args) # print statistics # by now we've forgotten about the original tags on the trees, @@ -429,10 +445,14 @@ def iterate_training(args, trainer, train_trees, train_sequences, transitions, d "Epoch %d finished" % trainer.epochs_trained, "Transitions correct: %s" % epoch_stats.transitions_correct, "Transitions incorrect: %s" % epoch_stats.transitions_incorrect, + ] + if args['orthogonal_learning_rate'] > 0.0 and args['orthogonal_initial_epoch'] <= trainer.epochs_trained: + stats_log_lines.append("Orthogonal loss for epoch: %.5f" % epoch_stats.orthogonal_loss) + stats_log_lines.extend([ "Total loss for epoch: %.5f" % epoch_stats.epoch_loss, "Dev score (%5d): %8f" % (trainer.epochs_trained, f1), "Best dev score (%5d): %8f" % (trainer.best_epoch, trainer.best_f1) - ] + ]) tlogger.info("\n ".join(stats_log_lines)) old_lr = trainer.optimizer.param_groups[0]['lr'] @@ -519,17 +539,17 @@ def iterate_training(args, trainer, train_trees, train_sequences, transitions, d return trainer -def train_model_one_epoch(epoch, trainer, transition_tensors, process_outputs, model_loss_function, epoch_data, oracle, args): +def train_model_one_epoch(epoch, trainer, transition_tensors, process_outputs, model_loss_function, orthogonal_loss_function, epoch_data, oracle, args): interval_starts = list(range(0, len(epoch_data), args['train_batch_size'])) random.shuffle(interval_starts) optimizer = trainer.optimizer - epoch_stats = EpochStats(0.0, Counter(), Counter(), Counter(), 0, 0) + epoch_stats = EpochStats(0.0, 0.0, Counter(), Counter(), Counter(), 0, 0) for batch_idx, interval_start in enumerate(tqdm(interval_starts, postfix="Epoch %d" % epoch)): batch = epoch_data[interval_start:interval_start+args['train_batch_size']] - batch_stats = train_model_one_batch(epoch, batch_idx, trainer.model, batch, transition_tensors, process_outputs, model_loss_function, oracle, args) + batch_stats = train_model_one_batch(epoch, batch_idx, trainer.model, batch, transition_tensors, process_outputs, model_loss_function, orthogonal_loss_function, oracle, args) trainer.batches_trained += 1 # Early in the training, some trees will be degenerate in a @@ -555,7 +575,7 @@ def train_model_one_epoch(epoch, trainer, transition_tensors, process_outputs, m return epoch_stats -def train_model_one_batch(epoch, batch_idx, model, training_batch, transition_tensors, process_outputs, model_loss_function, oracle, args): +def train_model_one_batch(epoch, batch_idx, model, training_batch, transition_tensors, process_outputs, model_loss_function, orthogonal_loss_function, oracle, args): """ Train the model for one batch @@ -651,8 +671,42 @@ def train_model_one_batch(epoch, batch_idx, model, training_batch, transition_te errors = torch.cat(all_errors) answers = torch.cat(all_answers) + orthogonal_loss = 0.0 + if epoch >= args['orthogonal_initial_epoch'] and orthogonal_loss_function is not None: + gold_results = model.analyze_trees([x.tree for x in training_batch], keep_constituents=True, keep_scores=False) + orthogonal_losses = [] + def build_losses(con_values, tree): + # this can skip preterminals + # but a preterminal in the middle of a phrase has a high chance of being + # a conjunction, a punctuation, or other non-function word anyway + subtrees = [x for x in tree.children if not x.is_preterminal()] + for subtree in subtrees: + build_losses(con_values, subtree) + for subtree_idx in range(len(subtrees)-1): + left = str(subtrees[subtree_idx]) + right = str(subtrees[subtree_idx+1]) + if left in con_values and right in con_values: + left_value = con_values[left].squeeze(0) + right_value = con_values[right].squeeze(0) + # technically could divide by norm to get the angle + # not sure that would help + # training dot product to be 0 is already enforcing orthogonal + mse = torch.dot(left_value, right_value) + orthogonal_losses.append(mse) + for result in gold_results: + gold_constituents = result.constituents + con_values = {} + for con in gold_constituents: + con_values[str(con.value)] = con.tree_hx + build_losses(con_values, result.gold) + orthogonal_losses = torch.stack(orthogonal_losses) + orthogonal_target = torch.zeros(orthogonal_losses.shape).to(orthogonal_losses.device) + orthogonal_loss = orthogonal_loss_function(orthogonal_losses, orthogonal_target) * args['orthogonal_learning_rate'] + + errors = process_outputs(errors) tree_loss = model_loss_function(errors, answers) + tree_loss += orthogonal_loss tree_loss.backward() if args['watch_regex']: matched = False @@ -670,13 +724,16 @@ def train_model_one_batch(epoch, batch_idx, model, training_batch, transition_te if not matched: tlogger.info(" (none found!)") if torch.any(torch.isnan(tree_loss)): + orthogonal_loss = 0.0 batch_loss = 0.0 nans = 1 else: batch_loss = tree_loss.item() + if not isinstance(orthogonal_loss, float): + orthogonal_loss = orthogonal_loss.item() nans = 0 - return EpochStats(batch_loss, transitions_correct, transitions_incorrect, repairs_used, fake_transitions_used, nans) + return EpochStats(batch_loss, orthogonal_loss, transitions_correct, transitions_incorrect, repairs_used, fake_transitions_used, nans) def run_dev_set(model, retagged_trees, original_trees, args, evaluator=None): """ diff --git a/stanza/models/constituency_parser.py b/stanza/models/constituency_parser.py index b5b8b7980..bff55084a 100644 --- a/stanza/models/constituency_parser.py +++ b/stanza/models/constituency_parser.py @@ -556,6 +556,9 @@ def build_argparse(): parser.add_argument('--grad_clipping', default=None, type=float, help='Clip abs(grad) to this amount. Use --no_grad_clipping to turn off grad clipping') parser.add_argument('--no_grad_clipping', action='store_const', const=None, dest='grad_clipping', help='Use --no_grad_clipping to turn off grad clipping') + parser.add_argument('--orthogonal_initial_epoch', default=1, type=int, help='When to start using the orthogonal loss') + parser.add_argument('--orthogonal_learning_rate', default=0.0, type=float, help='Multiplicative factor for the orthogonal loss') + # Large Margin is from Large Margin In Softmax Cross-Entropy Loss # it did not help on an Italian VIT test # scores went from 0.8252 to 0.8248