Skip to content

Commit 011d460

Browse files
authored
[speech-commands] Fix a flaky test and a wrong console message (#143)
- Fixes tensorflow/tfjs#1192 The previous assumption was that the kernel of the dense layer gets updated. But sometimes, the bias updates and the kernel doesn't. - A console.log message regarding large datasets and the use of fitDataset() was placed in the wrong place. This PR fixes that as well.
1 parent 053bd2d commit 011d460

File tree

2 files changed

+18
-15
lines changed

2 files changed

+18
-15
lines changed

speech-commands/src/browser_fft_recognizer.ts

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -917,12 +917,12 @@ class TransferBrowserFftSpeechCommandRecognizer extends
917917
const datasetDurationMillisThreshold =
918918
config.fitDatasetDurationMillisThreshold == null ?
919919
60e3 : config.fitDatasetDurationMillisThreshold;
920-
console.log(
921-
`Detected large dataset: total duration = ` +
922-
`${this.dataset.durationMillis()} ms > ` +
923-
`${datasetDurationMillisThreshold} ms. ` +
924-
`Training transfer model using fitDataset() instead of fit()`);
925920
if (this.dataset.durationMillis() > datasetDurationMillisThreshold) {
921+
console.log(
922+
`Detected large dataset: total duration = ` +
923+
`${this.dataset.durationMillis()} ms > ` +
924+
`${datasetDurationMillisThreshold} ms. ` +
925+
`Training transfer model using fitDataset() instead of fit()`);
926926
return await this.trainOnDataset(config);
927927
} else {
928928
return await this.trainOnTensors(config);

speech-commands/src/browser_fft_recognizer_test.ts

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -880,8 +880,9 @@ describeWithFlags('Browser FFT recognizer', tf.test_util.NODE_ENVS, () => {
880880
// tslint:disable-next-line:no-any
881881
const transferHead = (transfer as any).transferHead as tf.Sequential;
882882
const numLayers = transferHead.layers.length;
883-
const oldTransferKernel =
884-
transferHead.getLayer(null, numLayers - 1).getWeights()[0].dataSync();
883+
const oldTransferWeightValues =
884+
transferHead.getLayer(null, numLayers - 1).getWeights()
885+
.map(weight => weight.dataSync());
885886

886887
const history =
887888
await transfer.train({optimizer: tf.train.sgd(1)}) as tf.History;
@@ -897,19 +898,21 @@ describeWithFlags('Browser FFT recognizer', tf.test_util.NODE_ENVS, () => {
897898

898899
// Verify that the weights of the dense layer in the base model doesn't
899900
// change, i.e., is frozen.
900-
const newTransferKernel =
901-
transferHead.getLayer(null, numLayers - 1).getWeights()[0].dataSync();
901+
const newTransferWeightValues =
902+
transferHead.getLayer(null, numLayers - 1).getWeights()
903+
.map(weight => weight.dataSync());
902904
baseModelOldWeightValues.forEach((oldWeight, i) => {
903905
tf.test_util.expectArraysClose(baseModelNewWeightValues[i], oldWeight);
904906
});
905907
// Verify that the weight of the transfer-learning head model changes
906908
// after training.
907-
expect(tf.tensor1d(newTransferKernel)
908-
.sub(tf.tensor1d(oldTransferKernel))
909-
.abs()
910-
.max()
911-
.dataSync()[0])
912-
.toBeGreaterThan(1e-3);
909+
const maxWeightChanges = newTransferWeightValues.map(
910+
(newValues, i) => tf.tensor1d(newValues)
911+
.sub(tf.tensor1d(oldTransferWeightValues[i]))
912+
.abs()
913+
.max()
914+
.dataSync()[0]);
915+
expect(Math.max(...maxWeightChanges)).toBeGreaterThan(1e-3);
913916

914917
// Test recognize() with the transfer recognizer.
915918
const spectrogram =

0 commit comments

Comments
 (0)