diff --git a/neural_networks.py b/neural_networks.py index 40400c92..25d3e01b 100644 --- a/neural_networks.py +++ b/neural_networks.py @@ -657,6 +657,8 @@ def __init__(self, options,inp_dim): self._nr_of_filters = int(options['logmelfb_nr_filt']) self._stft_window_size = int(options['logmelfb_stft_window_size']) self._stft_window_shift = int(options['logmelfb_stft_window_shift']) + self._normalization_mode = options['logmelfb_normalization_mode'] + assert self._normalization_mode in ['batch', 'sequence', 'none'] self._use_cuda = strtobool(options['use_cuda']) self.out_dim = self._nr_of_filters self._mspec = torchaudio.transforms.MelSpectrogram( @@ -674,6 +676,18 @@ def _safe_log(inp, epsilon=1e-20): eps = eps.cuda() log_inp = torch.log10(torch.max(inp, eps.expand_as(inp))) return log_inp + def _normalize_features(data, normalization_mode, eps=1e-6): + if self._normalization_mode in ['none']: + return data + if self._normalization_mode in ['batch']: + norm_axis = (0, 1) + elif self._normalization_mode in ['sequence']: + norm_axis = (0) + mean = data.mean(norm_axis, keepdim=True) + std = data.std(norm_axis, keepdim=True) + out = (data - mean) / (std.clamp(min=eps)) + return out + assert x.shape[-1] == 1, 'Multi channel time signal processing not suppored yet' x_reshape_for_stft = torch.squeeze(x, -1).transpose(0, 1) if self._use_cuda: @@ -691,7 +705,8 @@ def _safe_log(inp, epsilon=1e-20): x_power_stft_reshape_for_filterbank_mult = x_power_stft.transpose(1, 2) mel_spec = self._mspec.fm(x_power_stft_reshape_for_filterbank_mult).transpose(0, 1) log_mel_spec = _safe_log(mel_spec) - out = log_mel_spec + norm_log_mel_spec = _normalize_features(log_mel_spec, self._normalization_mode) + out = norm_log_mel_spec return out class channel_averaging(nn.Module): diff --git a/proto/logMelFb.proto b/proto/logMelFb.proto index 45bdcb32..5c9ecd5e 100644 --- a/proto/logMelFb.proto +++ b/proto/logMelFb.proto @@ -2,5 +2,6 @@ logmelfb_nr_filt=int logmelfb_stft_window_size=int logmelfb_stft_window_shift=int +logmelfb_normalization_mode=str