  */

 import * as tf from '@tensorflow/tfjs';
+
 import {BrowserFftFeatureExtractor, SpectrogramCallback} from './browser_fft_extractor';
 import {loadMetadataJson, normalize, normalizeFloat32Array} from './browser_fft_utils';
 import {BACKGROUND_NOISE_TAG, Dataset} from './dataset';
 import {concatenateFloat32Arrays} from './generic_utils';
 import {balancedTrainValSplit} from './training_utils';
-import {EvaluateConfig, EvaluateResult, Example, ExampleCollectionOptions, RecognizeConfig, RecognizerCallback, RecognizerParams, ROCCurve, SpectrogramData, SpeechCommandRecognizer, SpeechCommandRecognizerMetadata, SpeechCommandRecognizerResult, StreamingRecognitionConfig, TransferLearnConfig, TransferSpeechCommandRecognizer, AudioDataAugmentationOptions} from './types';
+import {AudioDataAugmentationOptions, EvaluateConfig, EvaluateResult, Example, ExampleCollectionOptions, RecognizeConfig, RecognizerCallback, RecognizerParams, ROCCurve, SpectrogramData, SpeechCommandRecognizer, SpeechCommandRecognizerMetadata, SpeechCommandRecognizerResult, StreamingRecognitionConfig, TransferLearnConfig, TransferSpeechCommandRecognizer} from './types';
 import {version} from './version';

 export const UNKNOWN_TAG = '_unknown_';
@@ -206,7 +207,8 @@ export class BrowserFftSpeechCommandRecognizer implements
         () => `Expected overlapFactor to be >= 0 and < 1, but got ${
             overlapFactor}`);

-    const spectrogramCallback: SpectrogramCallback = async (x: tf.Tensor) => {
+    const spectrogramCallback: SpectrogramCallback =
+        async (x: tf.Tensor, timeData?: tf.Tensor) => {
       const normalizedX = normalize(x);
       let y: tf.Tensor;
       let embedding: tf.Tensor;
@@ -714,27 +716,33 @@ class TransferBrowserFftSpeechCommandRecognizer extends
       let lastIndex = -1;
       const spectrogramSnippets: Float32Array[] = [];

-      const spectrogramCallback: SpectrogramCallback = async (x: tf.Tensor) => {
+      const spectrogramCallback: SpectrogramCallback =
+          async (freqData: tf.Tensor, timeData?: tf.Tensor) => {
         // TODO(cais): can we consolidate the logic in the two branches?
         if (options.onSnippet == null) {
-          const normalizedX = normalize(x);
+          const normalizedX = normalize(freqData);
           this.dataset.addExample({
             label: word,
             spectrogram: {
               data: await normalizedX.data() as Float32Array,
               frameSize: this.nonBatchInputShape[1],
-            }
+            },
+            rawAudio: options.includeRawAudio ? {
+              data: await timeData.data() as Float32Array,
+              sampleRateHz: this.audioDataExtractor.sampleRateHz
+            } :
+                      undefined
           });
           normalizedX.dispose();
           await this.audioDataExtractor.stop();
           this.streaming = false;
           this.collateTransferWords();
           resolve({
-            data: await x.data() as Float32Array,
+            data: await freqData.data() as Float32Array,
             frameSize: this.nonBatchInputShape[1],
           });
         } else {
-          const data = await x.data() as Float32Array;
+          const data = await freqData.data() as Float32Array;
           if (lastIndex === -1) {
             lastIndex = data.length;
           }
@@ -763,8 +771,15 @@ class TransferBrowserFftSpeechCommandRecognizer extends
             data: normalized,
             frameSize: this.nonBatchInputShape[1]
           };
-          this.dataset.addExample(
-              {label: word, spectrogram: finalSpectrogram});
+          this.dataset.addExample({
+            label: word,
+            spectrogram: finalSpectrogram,
+            rawAudio: options.includeRawAudio ? {
+              data: await timeData.data() as Float32Array,
+              sampleRateHz: this.audioDataExtractor.sampleRateHz
+            } :
+                      undefined
+          });
           // TODO(cais): Fix 1-tensor memory leak.
           resolve(finalSpectrogram);
         }
@@ -777,7 +792,8 @@ class TransferBrowserFftSpeechCommandRecognizer extends
         columnTruncateLength: this.nonBatchInputShape[1],
         suppressionTimeMillis: 0,
         spectrogramCallback,
-        overlapFactor
+        overlapFactor,
+        includeRawAudio: options.includeRawAudio
       });
       this.audioDataExtractor.start(options.audioTrackConstraints);
     });
@@ -910,11 +926,9 @@ class TransferBrowserFftSpeechCommandRecognizer extends
     const numFrames = this.nonBatchInputShape[0];
     windowHopRatio = windowHopRatio || DEFAULT_WINDOW_HOP_RATIO;
     const hopFrames = Math.round(windowHopRatio * numFrames);
-    const out = this.dataset.getData(null, {
-      numFrames,
-      hopFrames,
-      ...augmentationOptions
-    }) as {xs: tf.Tensor4D, ys?: tf.Tensor2D};
+    const out = this.dataset.getData(
+                    null, {numFrames, hopFrames, ...augmentationOptions}) as
+        {xs: tf.Tensor4D, ys?: tf.Tensor2D};
     return {xs: out.xs, ys: out.ys as tf.Tensor};
   }

@@ -936,8 +950,8 @@ class TransferBrowserFftSpeechCommandRecognizer extends
    * `this.model.fitDataset`.
    */
   private collectTransferDataAsTfDataset(
-      windowHopRatio?: number, validationSplit = 0.15,
-      batchSize = 32, augmentationOptions?: AudioDataAugmentationOptions):
+      windowHopRatio?: number, validationSplit = 0.15, batchSize = 32,
+      augmentationOptions?: AudioDataAugmentationOptions):
       [tf.data.Dataset<{}>, tf.data.Dataset<{}>] {
     const numFrames = this.nonBatchInputShape[0];
     windowHopRatio = windowHopRatio || DEFAULT_WINDOW_HOP_RATIO;
@@ -1037,9 +1051,8 @@ class TransferBrowserFftSpeechCommandRecognizer extends
     const batchSize = config.batchSize == null ? 32 : config.batchSize;
     const windowHopRatio = config.windowHopRatio || DEFAULT_WINDOW_HOP_RATIO;
     const [trainDataset, valDataset] = this.collectTransferDataAsTfDataset(
-        windowHopRatio, config.validationSplit, batchSize, {
-          augmentByMixingNoiseRatio: config.augmentByMixingNoiseRatio
-        });
+        windowHopRatio, config.validationSplit, batchSize,
+        {augmentByMixingNoiseRatio: config.augmentByMixingNoiseRatio});
     const t0 = tf.util.now();
     const history = await this.model.fitDataset(trainDataset, {
       epochs: config.epochs,
@@ -1067,9 +1080,9 @@ class TransferBrowserFftSpeechCommandRecognizer extends
       Promise<tf.History|[tf.History, tf.History]> {
     // Prepare the data.
     const windowHopRatio = config.windowHopRatio || DEFAULT_WINDOW_HOP_RATIO;
-    const {xs, ys} = this.collectTransferDataAsTensors(windowHopRatio, {
-      augmentByMixingNoiseRatio: config.augmentByMixingNoiseRatio
-    });
+    const {xs, ys} = this.collectTransferDataAsTensors(
+        windowHopRatio,
+        {augmentByMixingNoiseRatio: config.augmentByMixingNoiseRatio});
     console.log(
         `Training data: xs.shape = ${xs.shape}, ys.shape = ${ys.shape}`);

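For callers, the visible effect of this diff is that `collectExample` accepts an `includeRawAudio` flag and, when it is set, stores the time-domain waveform and its sample rate alongside each collected spectrogram. Below is a minimal usage sketch, assuming the package's standard `create('BROWSER_FFT')` / `ensureModelLoaded()` / `createTransfer()` entry points and a browser context with microphone permission; the transfer-recognizer name and word label are placeholders.

import * as speechCommands from '@tensorflow-models/speech-commands';

async function collectWithRawAudio() {
  // Load the base BROWSER_FFT recognizer and derive a transfer recognizer.
  const recognizer = speechCommands.create('BROWSER_FFT');
  await recognizer.ensureModelLoaded();
  const transfer = recognizer.createTransfer('my-words');

  // With includeRawAudio set, the example added to the transfer recognizer's
  // dataset carries a rawAudio field (waveform data + sampleRateHz) in
  // addition to the spectrogram; the call itself still resolves to the
  // spectrogram of the captured window.
  const spectrogram = await transfer.collectExample('red', {
    includeRawAudio: true
  });
  console.log(
      `Collected ${spectrogram.data.length} spectrogram values ` +
      `(${spectrogram.frameSize} per frame).`);
}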