diff --git a/stanza/tests/constituency/test_trainer.py b/stanza/tests/constituency/test_trainer.py
index 8b8cc1335..2cfb19309 100644
--- a/stanza/tests/constituency/test_trainer.py
+++ b/stanza/tests/constituency/test_trainer.py
@@ -277,6 +277,20 @@ def test_early_dropout(self, wordvec_pretrain_file):
         if all(module.p == 0.0 for _, module in dropouts):
             raise AssertionError("All dropouts were 0 after training even though early_dropout was set to -1")
 
+    def test_contrastive(self, wordvec_pretrain_file):
+        """
+        Test that things don't blow up when a contrastive loss is used for a few iterations
+        """
+        with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
+            args = ['--contrastive_learning_rate', '0.1']
+            self.run_train_test(wordvec_pretrain_file, tmpdirname, num_epochs=6, extra_args=args)
+
+        with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
+            # TODO: get some kind of loss record back from the training process
+            # so that we can check it is being properly applied?
+            args = ['--contrastive_learning_rate', '0.1', '--contrastive_initial_epoch', '3']
+            self.run_train_test(wordvec_pretrain_file, tmpdirname, num_epochs=6, extra_args=args)
+
     def test_train_silver(self, wordvec_pretrain_file):
         """
         Test the whole thing for a few iterations on the fake data
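
For context, a minimal sketch of the schedule the second test case appears to exercise, assuming `--contrastive_initial_epoch` delays the contrastive term until that epoch and `--contrastive_learning_rate` scales it. The function and parameter names below are hypothetical illustrations, not stanza's actual internals.

    # Hypothetical sketch only: these names are not stanza's real API.
    def combined_loss(base_loss, contrastive_loss, epoch,
                      contrastive_learning_rate, contrastive_initial_epoch=0):
        """Apply the contrastive term only once the initial epoch is reached."""
        if contrastive_learning_rate <= 0 or epoch < contrastive_initial_epoch:
            return base_loss
        return base_loss + contrastive_learning_rate * contrastive_loss

    # With --contrastive_initial_epoch 3, the first epochs use the base loss
    # alone, which is why the test runs num_epochs=6 to cover both phases.
    assert combined_loss(1.0, 5.0, epoch=2, contrastive_learning_rate=0.1,
                         contrastive_initial_epoch=3) == 1.0
    assert combined_loss(1.0, 5.0, epoch=3, contrastive_learning_rate=0.1,
                         contrastive_initial_epoch=3) == 1.5

If the training loop reports its losses as suggested by the TODO above, the two test cases could assert that the contrastive term is zero before epoch 3 and nonzero afterward, rather than only checking that training completes.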