#!/usr/bin/env python3
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
"""
Run inference for pre-processed data with a trained model.
"""

import logging
import os
import random
import string
import sys
import time

import numpy as np
import sentencepiece as spm
import torch
import torchaudio
from fairseq import options, progress_bar, utils, tasks
from fairseq.meters import StopwatchMeter, TimeMeter
from fairseq.utils import import_user_module


logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


# Run on GPU when available, otherwise fall back to CPU.
dev = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def add_asr_eval_argument(parser):
    parser.add_argument("--ctc", action="store_true", help="decode a ctc model")
    parser.add_argument("--rnnt", action="store_true", help="decode a rnnt model")
    parser.add_argument("--kspmodel", default=None, help="sentence piece model")
    parser.add_argument(
        "--wfstlm", default=None, help="wfstlm on dictionary output units"
    )
    parser.add_argument(
        "--rnnt_decoding_type",
        default="greedy",
        help="rnnt decoding type (e.g. greedy)",
    )
    parser.add_argument(
        "--lm_weight",
        type=float,
        default=0.2,
        help="weight for wfstlm while interpolating with neural score",
    )
    parser.add_argument(
        "--rnnt_len_penalty",
        type=float,
        default=-0.5,
        help="rnnt length penalty on word level",
    )
    return parser


def check_args(args):
    assert args.path is not None, "--path required for generation!"
    assert (
        not args.sampling or args.nbest == args.beam
    ), "--sampling requires --nbest to be equal to --beam"
    assert (
        args.replace_unk is None or args.raw_text
    ), "--replace-unk requires a raw text dataset (--raw-text)"


def process_predictions(args, hypos, sp, tgt_dict):
    res = []
    for hypo in hypos[: min(len(hypos), args.nbest)]:
        hyp_pieces = tgt_dict.string(hypo["tokens"].int().cpu())
        hyp_words = sp.DecodePieces(hyp_pieces.split())
        res.append(hyp_words)
    return res
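
# A hedged illustration of the decode step above (the piece string is
# invented, not from a real model): tgt_dict.string() turns token ids into a
# sentencepiece piece string, and sp.DecodePieces() collapses the pieces back
# into plain words:
#
#   pieces = "▁HELLO ▁WORLD".split()   # hypothetical pieces
#   sp.DecodePieces(pieces)            # -> "HELLO WORLD"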


def optimize_models(args, models):
    """Optimize ensemble for generation."""
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()

        model.to(dev)
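
# Note: make_generation_fast_() and half() mutate each model in place, so the
# caller's `models` list is updated without a return value. Half precision is
# generally only worthwhile when `dev` is a CUDA device.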


def calc_mean_invstddev(feature):
    if len(feature.shape) != 2:
        raise ValueError("We expect the input feature to be a 2-D tensor")
    mean = np.mean(feature, axis=0)
    var = np.var(feature, axis=0)
    # Avoid division by (near) zero: if any per-dimension variance is smaller
    # than machine epsilon, pad the standard deviation before inverting.
    if (var < sys.float_info.epsilon).any():
        return mean, 1.0 / (np.sqrt(var) + sys.float_info.epsilon)
    return mean, 1.0 / np.sqrt(var)


def calcMN(features):
    # Apply cepstral mean and variance normalization (CMVN) per feature bin.
    mean, invstddev = calc_mean_invstddev(features)
    res = (features - mean) * invstddev
    return res
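
# A minimal sanity check for the CMVN helpers above (illustrative sketch; the
# shapes mirror the 80-bin fbank features used in transcribe() below):
#
#   feats = np.random.randn(100, 80).astype(np.float32)   # (frames, bins)
#   normed = calcMN(feats)
#   # per-bin mean ~ 0 and standard deviation ~ 1 after normalization
#   assert np.allclose(normed.mean(axis=0), 0.0, atol=1e-4)
#   assert np.allclose(normed.std(axis=0), 1.0, atol=1e-3)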


def transcribe(waveform, args, task, generator, models, sp, tgt_dict):
    r"""
    CUDA_VISIBLE_DEVICES=0 python infer_asr.py /Users/jamarshon/Documents/downloads/ \
    --task speech_recognition --max-tokens 10000000 --nbest 1 --path \
    /Users/jamarshon/Downloads/checkpoint_avg_60_80.pt --beam 20
    """
    num_features = 80
    # Extract 80-bin Kaldi-compatible log-mel filterbank features, then apply
    # CMVN to the resulting feature matrix.
    output = torchaudio.compliance.kaldi.fbank(waveform, num_mel_bins=num_features)
    output_cmvn = calcMN(output.cpu().detach().numpy())

    # size (m, n)
    source = torch.tensor(output_cmvn)
    source = source.to(dev)
    frames_lengths = torch.LongTensor([source.size(0)])

    # size (1, m, n). In general, if source is (x, m, n), then hypos is (x, ...)
    source.unsqueeze_(0)
    sample = {'net_input': {'src_tokens': source, 'src_lengths': frames_lengths}}

    hypos = task.inference_step(generator, models, sample)

    assert len(hypos) == 1
    transcription = []
    for i in range(len(hypos)):
        # Process top predictions
        hyp_words = process_predictions(args, hypos[i], sp, tgt_dict)
        transcription.append(hyp_words)

    logger.info('transcription: %s', transcription)
    return transcription
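
# Hedged usage sketch (illustrative; "/path/to/audio.wav" is a placeholder):
# any waveform can be passed to transcribe() once main() below has built the
# task, generator, models, sentencepiece processor and dictionary:
#
#   waveform, sr = torchaudio.load_wav("/path/to/audio.wav")
#   waveform = waveform.mean(0, True)  # downmix to mono
#   if sr != 16000:
#       waveform = torchaudio.transforms.Resample(sr, 16000)(waveform)
#   transcribe(waveform, args, task, generator, models, sp, tgt_dict)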

def main(args):
    check_args(args)
    import_user_module(args)

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 30000
    logger.info(args)
    # Load dataset splits
    task = tasks.setup_task(args)

    # Set dictionary
    tgt_dict = task.target_dictionary

    if args.ctc or args.rnnt:
        tgt_dict.add_symbol("<ctc_blank>")
    if args.ctc:
        logger.info("| decoding a ctc model")
    if args.rnnt:
        logger.info("| decoding a rnnt model")

    # Load ensemble
    logger.info("| loading model(s) from {}".format(args.path))
    models, _model_args = utils.load_ensemble_for_inference(
        args.path.split(":"),
        task,
        model_arg_overrides=eval(args.model_overrides),  # noqa
    )
    optimize_models(args, models)

    # Initialize generator
    generator = task.build_generator(args)

    sp = spm.SentencePieceProcessor()
    sp.Load(os.path.join(args.data, 'spm.model'))
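    # The sentencepiece model is expected at <args.data>/spm.model; it is
    # assumed to be the model used when the data was preprocessed (the script
    # does not verify this).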

    # TODO: replace this hard-coded input path (e.g. take it from the CLI)
    path = '/home/aakashns/speech_transcribe/deepspeech.pytorch/data/an4_dataset/train/an4/wav/cen8-mwhw-b.wav'
    if not os.path.exists(path):
        raise FileNotFoundError("Audio file not found: {}".format(path))
    waveform, sample_rate = torchaudio.load_wav(path)
    # Downmix to mono and resample to 16 kHz before feature extraction.
    waveform = waveform.mean(0, True)
    waveform = torchaudio.transforms.Resample(
        orig_freq=sample_rate, new_freq=16000
    )(waveform)

    logger.info('sample rate: %s, waveform shape: %s', sample_rate, waveform.shape)
    start = time.time()
    transcribe(waveform, args, task, generator, models, sp, tgt_dict)
    end = time.time()
    logger.info('transcription took %.2f seconds', end - start)


def cli_main():
    parser = options.get_generation_parser()
    parser = add_asr_eval_argument(parser)
    args = options.parse_args_and_arch(parser)
    main(args)


if __name__ == "__main__":
    cli_main()