From 3220e850e722b73d873460abb585f132f10ff944 Mon Sep 17 00:00:00 2001 From: husein zolkepli Date: Mon, 24 Aug 2020 14:40:42 +0800 Subject: [PATCH 1/2] fix ctc symbol modalities --- tensor2tensor/layers/modalities.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensor2tensor/layers/modalities.py b/tensor2tensor/layers/modalities.py index 2fdd24eb1..948bc3f6d 100644 --- a/tensor2tensor/layers/modalities.py +++ b/tensor2tensor/layers/modalities.py @@ -635,7 +635,7 @@ def video_raw_targets_bottom(x, model_hparams, vocab_size): # Loss transformations, applied to target features -def ctc_symbol_loss(top_out, targets, model_hparams, vocab_size, weight_fn): +def ctc_symbol_loss(top_out, targets, model_hparams, vocab_size, weights_fn): """Compute the CTC loss.""" del model_hparams, vocab_size # unused arg logits = top_out @@ -658,7 +658,7 @@ def ctc_symbol_loss(top_out, targets, model_hparams, vocab_size, weight_fn): time_major=False, preprocess_collapse_repeated=False, ctc_merge_repeated=False) - weights = weight_fn(targets) + weights = weights_fn(targets) return tf.reduce_sum(xent), tf.reduce_sum(weights) From fae15958569752688422bdcaf04fa772f4e6d658 Mon Sep 17 00:00:00 2001 From: husein zolkepli Date: Mon, 24 Aug 2020 14:50:13 +0800 Subject: [PATCH 2/2] fix ctc symbol modalities, bump --- tensor2tensor/layers/modalities.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tensor2tensor/layers/modalities.py b/tensor2tensor/layers/modalities.py index 948bc3f6d..e05e13319 100644 --- a/tensor2tensor/layers/modalities.py +++ b/tensor2tensor/layers/modalities.py @@ -637,6 +637,7 @@ def video_raw_targets_bottom(x, model_hparams, vocab_size): def ctc_symbol_loss(top_out, targets, model_hparams, vocab_size, weights_fn): """Compute the CTC loss.""" + del model_hparams, vocab_size # unused arg logits = top_out with tf.name_scope("ctc_loss", values=[logits, targets]):