Add a prototype of orthogonal loss for neighboring subtrees.
Still need to add crossing subtrees, and possibly more members of the same subtree, as orthogonal losses.

Uses MSELoss over the dot product of neighboring subtree embeddings to compute the orthogonal loss.
AngledLuffa committed Dec 5, 2024
1 parent 449feae commit d372541
Showing 2 changed files with 69 additions and 9 deletions.
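In essence, the new loss takes the dot product of the embeddings of two neighboring subtrees and drives it toward zero with MSELoss, so that positive and negative correlation are penalized alike. A minimal sketch of the idea, using random stand-in embeddings rather than the parser's actual tree_hx values:

import torch
from torch import nn

# hypothetical embeddings for two neighboring subtrees
left = torch.randn(64, requires_grad=True)
right = torch.randn(64, requires_grad=True)

# the dot product measures how correlated the two embeddings are;
# MSE against a zero target penalizes it in either direction,
# nudging the embeddings toward orthogonality
loss_fn = nn.MSELoss(reduction='sum')
dot = torch.dot(left, right)
loss = loss_fn(dot, torch.zeros_like(dot))
loss.backward()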
75 changes: 66 additions & 9 deletions stanza/models/constituency/parser_training.py
@@ -34,15 +34,16 @@

TrainItem = namedtuple("TrainItem", ['tree', 'gold_sequence', 'preterminals'])

class EpochStats(namedtuple("EpochStats", ['epoch_loss', 'transitions_correct', 'transitions_incorrect', 'repairs_used', 'fake_transitions_used', 'nans'])):
class EpochStats(namedtuple("EpochStats", ['epoch_loss', 'orthogonal_loss', 'transitions_correct', 'transitions_incorrect', 'repairs_used', 'fake_transitions_used', 'nans'])):
def __add__(self, other):
transitions_correct = self.transitions_correct + other.transitions_correct
transitions_incorrect = self.transitions_incorrect + other.transitions_incorrect
repairs_used = self.repairs_used + other.repairs_used
fake_transitions_used = self.fake_transitions_used + other.fake_transitions_used
epoch_loss = self.epoch_loss + other.epoch_loss
orthogonal_loss = self.orthogonal_loss + other.orthogonal_loss
nans = self.nans + other.nans
return EpochStats(epoch_loss, transitions_correct, transitions_incorrect, repairs_used, fake_transitions_used, nans)
return EpochStats(epoch_loss, orthogonal_loss, transitions_correct, transitions_incorrect, repairs_used, fake_transitions_used, nans)

def evaluate(args, model_file, retag_pipeline):
"""
@@ -335,6 +336,8 @@ def iterate_training(args, trainer, train_trees, train_sequences, transitions, d
add the errors to the list of things to backprop
advance the parsing state for each of the trees
"""
device = trainer.device

# Somewhat unusual, but possibly related to the extreme variability in length of trees
# Various experiments generally show about 0.5 F1 loss on various
# datasets when using 'mean' instead of 'sum' for reduction
@@ -360,6 +363,19 @@

device = trainer.device
model_loss_function.to(device)

if args['orthogonal_learning_rate'] > 0:
        # To encourage neighboring subtrees to be orthogonal,
        # we take the dot product of their embeddings and then try to
        # force that to 0 using MSELoss
        # That way, either a positive or a negative dot product
        # (either of which represents some form of correlation)
        # is penalized
orthogonal_loss_function = nn.MSELoss(reduction='sum')
orthogonal_loss_function.to(device)
else:
orthogonal_loss_function = None

transition_tensors = {x: torch.tensor(y, requires_grad=False, device=device).unsqueeze(0)
for (y, x) in enumerate(trainer.transitions)}
trainer.train()
@@ -409,7 +425,7 @@ def iterate_training(args, trainer, train_trees, train_sequences, transitions, d
epoch_data = epoch_data + epoch_silver_data
epoch_data.sort(key=lambda x: len(x[1]))

epoch_stats = train_model_one_epoch(trainer.epochs_trained, trainer, transition_tensors, process_outputs, model_loss_function, epoch_data, oracle, args)
epoch_stats = train_model_one_epoch(trainer.epochs_trained, trainer, transition_tensors, process_outputs, model_loss_function, orthogonal_loss_function, epoch_data, oracle, args)

# print statistics
# by now we've forgotten about the original tags on the trees,
@@ -429,10 +445,14 @@ def iterate_training(args, trainer, train_trees, train_sequences, transitions, d
"Epoch %d finished" % trainer.epochs_trained,
"Transitions correct: %s" % epoch_stats.transitions_correct,
"Transitions incorrect: %s" % epoch_stats.transitions_incorrect,
]
if args['orthogonal_learning_rate'] > 0.0 and args['orthogonal_initial_epoch'] <= trainer.epochs_trained:
stats_log_lines.append("Orthogonal loss for epoch: %.5f" % epoch_stats.orthogonal_loss)
stats_log_lines.extend([
"Total loss for epoch: %.5f" % epoch_stats.epoch_loss,
"Dev score (%5d): %8f" % (trainer.epochs_trained, f1),
"Best dev score (%5d): %8f" % (trainer.best_epoch, trainer.best_f1)
]
])
tlogger.info("\n ".join(stats_log_lines))

old_lr = trainer.optimizer.param_groups[0]['lr']
@@ -519,17 +539,17 @@ def iterate_training(args, trainer, train_trees, train_sequences, transitions, d

return trainer

def train_model_one_epoch(epoch, trainer, transition_tensors, process_outputs, model_loss_function, epoch_data, oracle, args):
def train_model_one_epoch(epoch, trainer, transition_tensors, process_outputs, model_loss_function, orthogonal_loss_function, epoch_data, oracle, args):
interval_starts = list(range(0, len(epoch_data), args['train_batch_size']))
random.shuffle(interval_starts)

optimizer = trainer.optimizer

epoch_stats = EpochStats(0.0, Counter(), Counter(), Counter(), 0, 0)
epoch_stats = EpochStats(0.0, 0.0, Counter(), Counter(), Counter(), 0, 0)

for batch_idx, interval_start in enumerate(tqdm(interval_starts, postfix="Epoch %d" % epoch)):
batch = epoch_data[interval_start:interval_start+args['train_batch_size']]
batch_stats = train_model_one_batch(epoch, batch_idx, trainer.model, batch, transition_tensors, process_outputs, model_loss_function, oracle, args)
batch_stats = train_model_one_batch(epoch, batch_idx, trainer.model, batch, transition_tensors, process_outputs, model_loss_function, orthogonal_loss_function, oracle, args)
trainer.batches_trained += 1

# Early in the training, some trees will be degenerate in a
@@ -555,7 +575,7 @@ def train_model_one_epoch(epoch, trainer, transition_tensors, process_outputs, m

return epoch_stats

def train_model_one_batch(epoch, batch_idx, model, training_batch, transition_tensors, process_outputs, model_loss_function, oracle, args):
def train_model_one_batch(epoch, batch_idx, model, training_batch, transition_tensors, process_outputs, model_loss_function, orthogonal_loss_function, oracle, args):
"""
Train the model for one batch
@@ -651,8 +671,42 @@ def train_model_one_batch(epoch, batch_idx, model, training_batch, transition_te
errors = torch.cat(all_errors)
answers = torch.cat(all_answers)

orthogonal_loss = 0.0
if epoch >= args['orthogonal_initial_epoch'] and orthogonal_loss_function is not None:
gold_results = model.analyze_trees([x.tree for x in training_batch], keep_constituents=True, keep_scores=False)
orthogonal_losses = []
def build_losses(con_values, tree):
            # this skips preterminals
            # but a preterminal in the middle of a phrase has a high chance of being
            # a conjunction, a punctuation mark, or some other function word anyway
subtrees = [x for x in tree.children if not x.is_preterminal()]
for subtree in subtrees:
build_losses(con_values, subtree)
for subtree_idx in range(len(subtrees)-1):
left = str(subtrees[subtree_idx])
right = str(subtrees[subtree_idx+1])
if left in con_values and right in con_values:
left_value = con_values[left].squeeze(0)
right_value = con_values[right].squeeze(0)
                    # technically we could divide by the norms to get the cosine of the angle
                    # not sure that would help, though:
                    # training the dot product to be 0 already enforces orthogonality
                    dot_product = torch.dot(left_value, right_value)
                    orthogonal_losses.append(dot_product)
for result in gold_results:
gold_constituents = result.constituents
con_values = {}
for con in gold_constituents:
con_values[str(con.value)] = con.tree_hx
build_losses(con_values, result.gold)
orthogonal_losses = torch.stack(orthogonal_losses)
orthogonal_target = torch.zeros(orthogonal_losses.shape).to(orthogonal_losses.device)
orthogonal_loss = orthogonal_loss_function(orthogonal_losses, orthogonal_target) * args['orthogonal_learning_rate']


errors = process_outputs(errors)
tree_loss = model_loss_function(errors, answers)
tree_loss += orthogonal_loss
tree_loss.backward()
if args['watch_regex']:
matched = False
@@ -670,13 +724,16 @@ def train_model_one_batch(epoch, batch_idx, model, training_batch, transition_te
if not matched:
tlogger.info(" (none found!)")
if torch.any(torch.isnan(tree_loss)):
orthogonal_loss = 0.0
batch_loss = 0.0
nans = 1
else:
batch_loss = tree_loss.item()
if not isinstance(orthogonal_loss, float):
orthogonal_loss = orthogonal_loss.item()
nans = 0

return EpochStats(batch_loss, transitions_correct, transitions_incorrect, repairs_used, fake_transitions_used, nans)
return EpochStats(batch_loss, orthogonal_loss, transitions_correct, transitions_incorrect, repairs_used, fake_transitions_used, nans)

def run_dev_set(model, retagged_trees, original_trees, args, evaluator=None):
"""
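As a toy illustration of what build_losses visits: it recurses through the tree and pairs each non-preterminal subtree with its immediate right sibling. The sketch below uses hypothetical (label, children) tuples instead of stanza's tree objects, and just returns the labels of the neighboring pairs, where the real code keys subtrees by their string form and looks up their tree_hx embeddings:

# stand-in for the traversal in build_losses: collect each pair of
# neighboring non-preterminal sibling subtrees, recursively
def neighbor_pairs(tree):
    label, children = tree
    # a preterminal here is a node whose only child is a bare word string
    subtrees = [child for child in children if not isinstance(child[1][0], str)]
    pairs = []
    for subtree in subtrees:
        pairs.extend(neighbor_pairs(subtree))
    for i in range(len(subtrees) - 1):
        pairs.append((subtrees[i][0], subtrees[i + 1][0]))
    return pairs

tree = ("S", [("NP", [("DT", ["the"]), ("NN", ["cat"])]),
              ("VP", [("VBD", ["sat"])])])
print(neighbor_pairs(tree))  # -> [('NP', 'VP')]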
3 changes: 3 additions & 0 deletions stanza/models/constituency_parser.py
@@ -556,6 +556,9 @@ def build_argparse():
parser.add_argument('--grad_clipping', default=None, type=float, help='Clip abs(grad) to this amount. Use --no_grad_clipping to turn off grad clipping')
parser.add_argument('--no_grad_clipping', action='store_const', const=None, dest='grad_clipping', help='Use --no_grad_clipping to turn off grad clipping')

parser.add_argument('--orthogonal_initial_epoch', default=1, type=int, help='When to start using the orthogonal loss')
parser.add_argument('--orthogonal_learning_rate', default=0.0, type=float, help='Multiplicative factor for the orthogonal loss')

# Large Margin is from Large Margin In Softmax Cross-Entropy Loss
# it did not help on an Italian VIT test
# scores went from 0.8252 to 0.8248
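Together, these defaults keep the new loss off (--orthogonal_learning_rate 0.0); when enabled, it only starts contributing once the trainer reaches --orthogonal_initial_epoch. A hypothetical invocation, assuming the module's standard training entry point, with illustrative flag values and the usual training arguments (treebank, embeddings, and so on) omitted:

python -m stanza.models.constituency_parser \
    --orthogonal_learning_rate 0.01 \
    --orthogonal_initial_epoch 5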
