From 30c4f837e5385927f52e107e9c22255184120b81 Mon Sep 17 00:00:00 2001 From: Meysam Moqaddam <65253484+MeysamMoghaddam@users.noreply.github.com> Date: Tue, 1 Apr 2025 08:00:35 +0330 Subject: [PATCH] Fix incorrect reference to model in Transformer class fix: correct model reference to self in Transformer class Changed model.compute_loss() to self.compute_loss() in train_step and test_step methods to properly reference the class instance. --- examples/audio/transformer_asr.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/audio/transformer_asr.py b/examples/audio/transformer_asr.py index f7b1d7130e..b661c73c83 100644 --- a/examples/audio/transformer_asr.py +++ b/examples/audio/transformer_asr.py @@ -247,7 +247,7 @@ def train_step(self, batch): preds = self([source, dec_input]) one_hot = tf.one_hot(dec_target, depth=self.num_classes) mask = tf.math.logical_not(tf.math.equal(dec_target, 0)) - loss = model.compute_loss(None, one_hot, preds, sample_weight=mask) + loss = self.compute_loss(None, one_hot, preds, sample_weight=mask) trainable_vars = self.trainable_variables gradients = tape.gradient(loss, trainable_vars) self.optimizer.apply_gradients(zip(gradients, trainable_vars)) @@ -262,7 +262,7 @@ def test_step(self, batch): preds = self([source, dec_input]) one_hot = tf.one_hot(dec_target, depth=self.num_classes) mask = tf.math.logical_not(tf.math.equal(dec_target, 0)) - loss = model.compute_loss(None, one_hot, preds, sample_weight=mask) + loss = self.compute_loss(None, one_hot, preds, sample_weight=mask) self.loss_metric.update_state(loss) return {"loss": self.loss_metric.result()}