Skip to content

Commit a012718

Browse files
authored
[speech-commands] Add option to collect raw audio waveform during collectExample() (tensorflow#195)
... in addition to spectrograms. - The option argument to `collectExample()` now has a new field called "includeRawAudio". If set to true, the collected example will include the raw audio waveform. Use caution when enabling this feature, as it will increase the memory consumption of the examples several-fold. - The demo page now has a new checkbox called "Include audio waveform" that demonstrates this new feature, along with playback of the audio waveforms through WebAudio. - Add `speechCommands.utils.playRawAudio()`.
1 parent cee11b3 commit a012718

12 files changed

+211
-37
lines changed

speech-commands/demo/dataset-vis.js

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,10 +113,11 @@ export class DatasetViz {
113113
* @param {SpectrogramData} spectrogram Optional spectrogram data.
114114
* If provided, will use it as is. If not provided, will use WebAudio
115115
* to collect an example.
116+
* @param {RawAudio} rawAudio Raw audio waveform. Optional
116117
* @param {string} uid UID of the example being drawn. Must match the UID
117118
* of the example from `this.transferRecognizer`.
118119
*/
119-
async drawExample(wordDiv, word, spectrogram, uid) {
120+
async drawExample(wordDiv, word, spectrogram, rawAudio, uid) {
120121
if (uid == null) {
121122
throw new Error('Error: UID is not provided for pre-existing example.');
122123
}
@@ -192,6 +193,17 @@ export class DatasetViz {
192193
keyFrameIndex: spectrogram.keyFrameIndex
193194
});
194195

196+
if (rawAudio != null) {
197+
const playButton = document.createElement('button');
198+
playButton.textContent = '▶️';
199+
playButton.addEventListener('click', () => {
200+
playButton.disabled = true;
201+
speechCommands.utils.playRawAudio(
202+
rawAudio, () => playButton.disabled = false);
203+
});
204+
wordDiv.appendChild(playButton);
205+
}
206+
195207
// Create Delete button.
196208
const deleteButton = document.createElement('button');
197209
deleteButton.textContent = 'X';
@@ -250,7 +262,8 @@ export class DatasetViz {
250262
}
251263

252264
const spectrogram = example.example.spectrogram;
253-
await this.drawExample(wordDiv, word, spectrogram, example.uid);
265+
await this.drawExample(
266+
wordDiv, word, spectrogram, example.example.rawAudio, example.uid);
254267
} else {
255268
removeNonFixedChildrenFromWordDiv(wordDiv);
256269
}

speech-commands/demo/index.html

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@
3131
<option value="1">Duration x1</option>
3232
<option value="2" selected="true">Duration x2</option>
3333
</select>
34+
35+
<input type="checkbox" id="include-audio-waveform">
36+
<span id="include-audio-waveform-label">Include audio waveform</span>
37+
3438
<button id="enter-learn-words" disabled="true">Enter transfer words</button>
3539

3640
<div id="transfer-learn-history"></div>

speech-commands/demo/index.js

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@ const transferModelNameInput = document.getElementById('transfer-model-name');
5656
const learnWordsInput = document.getElementById('learn-words');
5757
const durationMultiplierSelect = document.getElementById('duration-multiplier');
5858
const enterLearnWordsButton = document.getElementById('enter-learn-words');
59+
const includeTimeDomainWaveformCheckbox =
60+
document.getElementById('include-audio-waveform');
5961
const collectButtonsDiv = document.getElementById('collect-words');
6062
const startTransferLearnButton =
6163
document.getElementById('start-transfer-learn');
@@ -278,18 +280,22 @@ function createWordDivs(transferWords) {
278280
}
279281
}
280282

283+
collectExampleOptions.includeRawAudio =
284+
includeTimeDomainWaveformCheckbox.checked;
281285
const spectrogram = await transferRecognizer.collectExample(
282286
word, collectExampleOptions);
283287

288+
284289
if (intervalJob != null) {
285290
clearInterval(intervalJob);
286291
}
287292
if (progressBar != null) {
288293
wordDiv.removeChild(progressBar);
289294
}
290295
const examples = transferRecognizer.getExamples(word)
291-
const exampleUID = examples[examples.length - 1].uid;
292-
await datasetViz.drawExample(wordDiv, word, spectrogram, exampleUID);
296+
const example = examples[examples.length - 1];
297+
await datasetViz.drawExample(
298+
wordDiv, word, spectrogram, example.example.rawAudio, example.uid);
293299
enableAllCollectWordButtons();
294300
});
295301
}

speech-commands/demo/style.css

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,3 +173,11 @@ textarea {
173173
input[type=checkbox] {
174174
transform: scale(2);
175175
}
176+
177+
#include-audio-waveform {
178+
margin-left: 20px;
179+
}
180+
181+
#include-audio-waveform-label {
182+
font-size: 17px;
183+
}

speech-commands/src/browser_fft_extractor.ts

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,12 @@
2020
*/
2121

2222
import * as tf from '@tensorflow/tfjs';
23+
2324
import {getAudioContextConstructor, getAudioMediaStream} from './browser_fft_utils';
2425
import {FeatureExtractor, RecognizerParams} from './types';
2526

26-
export type SpectrogramCallback = (x: tf.Tensor) => Promise<boolean>;
27+
export type SpectrogramCallback = (freqData: tf.Tensor, timeData?: tf.Tensor) =>
28+
Promise<boolean>;
2729

2830
/**
2931
* Configurations for constructing BrowserFftFeatureExtractor.
@@ -68,6 +70,14 @@ export interface BrowserFftFeatureExtractorConfig extends RecognizerParams {
6870
* will be taken every 600 ms.
6971
*/
7072
overlapFactor: number;
73+
74+
/**
75+
* Whether to collect the raw time-domain audio waveform in addition to the
76+
* spectrogram.
77+
*
78+
* Default: `false`.
79+
*/
80+
includeRawAudio?: boolean;
7181
}
7282

7383
/**
@@ -91,6 +101,7 @@ export class BrowserFftFeatureExtractor implements FeatureExtractor {
91101
// Overlapping factor: the ratio between the temporal spacing between
92102
// consecutive spectrograms and the length of each individual spectrogram.
93103
readonly overlapFactor: number;
104+
readonly includeRawAudio: boolean;
94105

95106
private readonly spectrogramCallback: SpectrogramCallback;
96107

@@ -101,7 +112,9 @@ export class BrowserFftFeatureExtractor implements FeatureExtractor {
101112
private analyser: AnalyserNode;
102113
private tracker: Tracker;
103114
private freqData: Float32Array;
115+
private timeData: Float32Array;
104116
private freqDataQueue: Float32Array[];
117+
private timeDataQueue: Float32Array[];
105118
// tslint:disable-next-line:no-any
106119
private frameIntervalTask: any;
107120
private frameDurationMillis: number;
@@ -144,6 +157,7 @@ export class BrowserFftFeatureExtractor implements FeatureExtractor {
144157
this.frameDurationMillis = this.fftSize / this.sampleRateHz * 1e3;
145158
this.columnTruncateLength = config.columnTruncateLength || this.fftSize;
146159
this.overlapFactor = config.overlapFactor;
160+
this.includeRawAudio = config.includeRawAudio;
147161

148162
tf.util.assert(
149163
this.overlapFactor >= 0 && this.overlapFactor < 1,
@@ -183,6 +197,10 @@ export class BrowserFftFeatureExtractor implements FeatureExtractor {
183197
// Reset the queue.
184198
this.freqDataQueue = [];
185199
this.freqData = new Float32Array(this.fftSize);
200+
if (this.includeRawAudio) {
201+
this.timeDataQueue = [];
202+
this.timeData = new Float32Array(this.fftSize);
203+
}
186204
const period =
187205
Math.max(1, Math.round(this.numFrames * (1 - this.overlapFactor)));
188206
this.tracker = new Tracker(
@@ -199,20 +217,31 @@ export class BrowserFftFeatureExtractor implements FeatureExtractor {
199217
}
200218

201219
this.freqDataQueue.push(this.freqData.slice(0, this.columnTruncateLength));
220+
if (this.includeRawAudio) {
221+
this.analyser.getFloatTimeDomainData(this.timeData);
222+
this.timeDataQueue.push(this.timeData.slice());
223+
}
202224
if (this.freqDataQueue.length > this.numFrames) {
203225
// Drop the oldest frame (least recent).
204226
this.freqDataQueue.shift();
205227
}
206228
const shouldFire = this.tracker.tick();
207229
if (shouldFire) {
208230
const freqData = flattenQueue(this.freqDataQueue);
209-
const inputTensor = getInputTensorFromFrequencyData(
231+
const freqDataTensor = getInputTensorFromFrequencyData(
210232
freqData, [1, this.numFrames, this.columnTruncateLength, 1]);
211-
const shouldRest = await this.spectrogramCallback(inputTensor);
233+
let timeDataTensor: tf.Tensor;
234+
if (this.includeRawAudio) {
235+
const timeData = flattenQueue(this.timeDataQueue);
236+
timeDataTensor = getInputTensorFromFrequencyData(
237+
timeData, [1, this.numFrames * this.fftSize]);
238+
}
239+
const shouldRest =
240+
await this.spectrogramCallback(freqDataTensor, timeDataTensor);
212241
if (shouldRest) {
213242
this.tracker.suppress();
214243
}
215-
inputTensor.dispose();
244+
tf.dispose([freqDataTensor, timeDataTensor]);
216245
}
217246
}
218247

speech-commands/src/browser_fft_recognizer.ts

Lines changed: 36 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,13 @@
1616
*/
1717

1818
import * as tf from '@tensorflow/tfjs';
19+
1920
import {BrowserFftFeatureExtractor, SpectrogramCallback} from './browser_fft_extractor';
2021
import {loadMetadataJson, normalize, normalizeFloat32Array} from './browser_fft_utils';
2122
import {BACKGROUND_NOISE_TAG, Dataset} from './dataset';
2223
import {concatenateFloat32Arrays} from './generic_utils';
2324
import {balancedTrainValSplit} from './training_utils';
24-
import {EvaluateConfig, EvaluateResult, Example, ExampleCollectionOptions, RecognizeConfig, RecognizerCallback, RecognizerParams, ROCCurve, SpectrogramData, SpeechCommandRecognizer, SpeechCommandRecognizerMetadata, SpeechCommandRecognizerResult, StreamingRecognitionConfig, TransferLearnConfig, TransferSpeechCommandRecognizer, AudioDataAugmentationOptions} from './types';
25+
import {AudioDataAugmentationOptions, EvaluateConfig, EvaluateResult, Example, ExampleCollectionOptions, RecognizeConfig, RecognizerCallback, RecognizerParams, ROCCurve, SpectrogramData, SpeechCommandRecognizer, SpeechCommandRecognizerMetadata, SpeechCommandRecognizerResult, StreamingRecognitionConfig, TransferLearnConfig, TransferSpeechCommandRecognizer} from './types';
2526
import {version} from './version';
2627

2728
export const UNKNOWN_TAG = '_unknown_';
@@ -206,7 +207,8 @@ export class BrowserFftSpeechCommandRecognizer implements
206207
() => `Expected overlapFactor to be >= 0 and < 1, but got ${
207208
overlapFactor}`);
208209

209-
const spectrogramCallback: SpectrogramCallback = async (x: tf.Tensor) => {
210+
const spectrogramCallback: SpectrogramCallback =
211+
async (x: tf.Tensor, timeData?: tf.Tensor) => {
210212
const normalizedX = normalize(x);
211213
let y: tf.Tensor;
212214
let embedding: tf.Tensor;
@@ -714,27 +716,33 @@ class TransferBrowserFftSpeechCommandRecognizer extends
714716
let lastIndex = -1;
715717
const spectrogramSnippets: Float32Array[] = [];
716718

717-
const spectrogramCallback: SpectrogramCallback = async (x: tf.Tensor) => {
719+
const spectrogramCallback: SpectrogramCallback =
720+
async (freqData: tf.Tensor, timeData?: tf.Tensor) => {
718721
// TODO(cais): can we consolidate the logic in the two branches?
719722
if (options.onSnippet == null) {
720-
const normalizedX = normalize(x);
723+
const normalizedX = normalize(freqData);
721724
this.dataset.addExample({
722725
label: word,
723726
spectrogram: {
724727
data: await normalizedX.data() as Float32Array,
725728
frameSize: this.nonBatchInputShape[1],
726-
}
729+
},
730+
rawAudio: options.includeRawAudio ? {
731+
data: await timeData.data() as Float32Array,
732+
sampleRateHz: this.audioDataExtractor.sampleRateHz
733+
} :
734+
undefined
727735
});
728736
normalizedX.dispose();
729737
await this.audioDataExtractor.stop();
730738
this.streaming = false;
731739
this.collateTransferWords();
732740
resolve({
733-
data: await x.data() as Float32Array,
741+
data: await freqData.data() as Float32Array,
734742
frameSize: this.nonBatchInputShape[1],
735743
});
736744
} else {
737-
const data = await x.data() as Float32Array;
745+
const data = await freqData.data() as Float32Array;
738746
if (lastIndex === -1) {
739747
lastIndex = data.length;
740748
}
@@ -763,8 +771,15 @@ class TransferBrowserFftSpeechCommandRecognizer extends
763771
data: normalized,
764772
frameSize: this.nonBatchInputShape[1]
765773
};
766-
this.dataset.addExample(
767-
{label: word, spectrogram: finalSpectrogram});
774+
this.dataset.addExample({
775+
label: word,
776+
spectrogram: finalSpectrogram,
777+
rawAudio: options.includeRawAudio ? {
778+
data: await timeData.data() as Float32Array,
779+
sampleRateHz: this.audioDataExtractor.sampleRateHz
780+
} :
781+
undefined
782+
});
768783
// TODO(cais): Fix 1-tensor memory leak.
769784
resolve(finalSpectrogram);
770785
}
@@ -777,7 +792,8 @@ class TransferBrowserFftSpeechCommandRecognizer extends
777792
columnTruncateLength: this.nonBatchInputShape[1],
778793
suppressionTimeMillis: 0,
779794
spectrogramCallback,
780-
overlapFactor
795+
overlapFactor,
796+
includeRawAudio: options.includeRawAudio
781797
});
782798
this.audioDataExtractor.start(options.audioTrackConstraints);
783799
});
@@ -910,11 +926,9 @@ class TransferBrowserFftSpeechCommandRecognizer extends
910926
const numFrames = this.nonBatchInputShape[0];
911927
windowHopRatio = windowHopRatio || DEFAULT_WINDOW_HOP_RATIO;
912928
const hopFrames = Math.round(windowHopRatio * numFrames);
913-
const out = this.dataset.getData(null, {
914-
numFrames,
915-
hopFrames,
916-
...augmentationOptions
917-
}) as {xs: tf.Tensor4D, ys?: tf.Tensor2D};
929+
const out = this.dataset.getData(
930+
null, {numFrames, hopFrames, ...augmentationOptions}) as
931+
{xs: tf.Tensor4D, ys?: tf.Tensor2D};
918932
return {xs: out.xs, ys: out.ys as tf.Tensor};
919933
}
920934

@@ -936,8 +950,8 @@ class TransferBrowserFftSpeechCommandRecognizer extends
936950
* `this.model.fitDataset`.
937951
*/
938952
private collectTransferDataAsTfDataset(
939-
windowHopRatio?: number, validationSplit = 0.15,
940-
batchSize = 32, augmentationOptions?: AudioDataAugmentationOptions):
953+
windowHopRatio?: number, validationSplit = 0.15, batchSize = 32,
954+
augmentationOptions?: AudioDataAugmentationOptions):
941955
[tf.data.Dataset<{}>, tf.data.Dataset<{}>] {
942956
const numFrames = this.nonBatchInputShape[0];
943957
windowHopRatio = windowHopRatio || DEFAULT_WINDOW_HOP_RATIO;
@@ -1037,9 +1051,8 @@ class TransferBrowserFftSpeechCommandRecognizer extends
10371051
const batchSize = config.batchSize == null ? 32 : config.batchSize;
10381052
const windowHopRatio = config.windowHopRatio || DEFAULT_WINDOW_HOP_RATIO;
10391053
const [trainDataset, valDataset] = this.collectTransferDataAsTfDataset(
1040-
windowHopRatio, config.validationSplit, batchSize, {
1041-
augmentByMixingNoiseRatio: config.augmentByMixingNoiseRatio
1042-
});
1054+
windowHopRatio, config.validationSplit, batchSize,
1055+
{augmentByMixingNoiseRatio: config.augmentByMixingNoiseRatio});
10431056
const t0 = tf.util.now();
10441057
const history = await this.model.fitDataset(trainDataset, {
10451058
epochs: config.epochs,
@@ -1067,9 +1080,9 @@ class TransferBrowserFftSpeechCommandRecognizer extends
10671080
Promise<tf.History|[tf.History, tf.History]> {
10681081
// Prepare the data.
10691082
const windowHopRatio = config.windowHopRatio || DEFAULT_WINDOW_HOP_RATIO;
1070-
const {xs, ys} = this.collectTransferDataAsTensors(windowHopRatio, {
1071-
augmentByMixingNoiseRatio: config.augmentByMixingNoiseRatio
1072-
});
1083+
const {xs, ys} = this.collectTransferDataAsTensors(
1084+
windowHopRatio,
1085+
{augmentByMixingNoiseRatio: config.augmentByMixingNoiseRatio});
10731086
console.log(
10741087
`Training data: xs.shape = ${xs.shape}, ys.shape = ${ys.shape}`);
10751088

0 commit comments

Comments
 (0)