
Commit 24ba3b6

Author: Aakash N S (committed)
Initial working commit
0 parents  commit 24ba3b6


296 files changed: 38482 additions and 0 deletions


.gitignore

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
asr-demo/data/

asr-demo/.gitignore

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
__pycache__
*.pyc
*.swp

asr-demo/README.md

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
# asr-demo

We recommend that you use [conda](https://docs.conda.io/en/latest/miniconda.html) to install these dependencies.

What you need to run this demo:

- [python3](https://www.python.org/download/releases/3.0/)
- [torchaudio](https://github.com/pytorch/audio/tree/master/torchaudio)
- [pytorch](https://pytorch.org/)
- [librosa](https://librosa.github.io/librosa/)
- [fairseq](https://github.com/pytorch/fairseq) (clone the GitHub repository)

Models:

- [dictionary](https://download.pytorch.org/models/audio/dict.txt)
- [sentence piece model](https://download.pytorch.org/models/audio/spm.model)
- [model](https://download.pytorch.org/models/audio/checkpoint_avg_60_80.pt)

Example command:

Save the dictionary, sentence piece model, and model in `data` (see the download sketch below), then run:

    python interactive_asr.py ./data --max-tokens 10000000 --nbest 1 --path ./data/model.pt --beam 40 --task speech_recognition --user-dir ../fairseq/examples/speech_recognition
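
The command above points `--path` at `./data/model.pt`, while the checkpoint downloads as `checkpoint_avg_60_80.pt`. A minimal download sketch (assuming network access and that the three URLs above are still live; saving the checkpoint as `model.pt` is an assumption made to match the command):

```python
# download_models.py -- fetch the three files listed above into ./data
import os
import urllib.request

URLS = {
    "dict.txt": "https://download.pytorch.org/models/audio/dict.txt",
    "spm.model": "https://download.pytorch.org/models/audio/spm.model",
    # saved as model.pt so that --path ./data/model.pt resolves (assumption)
    "model.pt": "https://download.pytorch.org/models/audio/checkpoint_avg_60_80.pt",
}

os.makedirs("data", exist_ok=True)
for name, url in URLS.items():
    dest = os.path.join("data", name)
    if not os.path.exists(dest):
        print("fetching", url, "->", dest)
        urllib.request.urlretrieve(url, dest)
```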

asr-demo/infer_file.py

Lines changed: 208 additions & 0 deletions
@@ -0,0 +1,208 @@
#!/usr/bin/env python3
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
"""
Run inference for pre-processed data with a trained model.
"""

import logging
import os
import random
import string
import sys

import sentencepiece as spm
import torch
import torchaudio
import numpy as np
from fairseq import options, progress_bar, utils, tasks
from fairseq.meters import StopwatchMeter, TimeMeter
from fairseq.utils import import_user_module


logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


dev = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def add_asr_eval_argument(parser):
    parser.add_argument("--ctc", action="store_true", help="decode a ctc model")
    parser.add_argument("--rnnt", default=False, help="decode a rnnt model")
    parser.add_argument("--kspmodel", default=None, help="sentence piece model")
    parser.add_argument(
        "--wfstlm", default=None, help="wfstlm on dictionary output units"
    )
    parser.add_argument(
        "--rnnt_decoding_type",
        default="greedy",
        help="rnnt decoding type (e.g. greedy)",
    )
    parser.add_argument(
        "--lm_weight",
        default=0.2,
        help="weight for wfstlm while interpolating with neural score",
    )
    parser.add_argument(
        "--rnnt_len_penalty", default=-0.5, help="rnnt length penalty on word level"
    )
    return parser
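
# note: the flags above are added without type=..., so values supplied on the
# command line (e.g. --lm_weight 0.5) arrive as strings; the typed defaults
# apply only when a flag is omitted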


def check_args(args):
    assert args.path is not None, "--path required for generation!"
    assert (
        not args.sampling or args.nbest == args.beam
    ), "--sampling requires --nbest to be equal to --beam"
    assert (
        args.replace_unk is None or args.raw_text
    ), "--replace-unk requires a raw text dataset (--raw-text)"


def process_predictions(args, hypos, sp, tgt_dict):
    res = []
    for hypo in hypos[: min(len(hypos), args.nbest)]:
        hyp_pieces = tgt_dict.string(hypo["tokens"].int().cpu())
        hyp_words = sp.DecodePieces(hyp_pieces.split())
        res.append(hyp_words)
    return res
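
# tgt_dict.string() renders the hypothesis token ids as sentencepiece pieces,
# and sp.DecodePieces() merges those pieces back into words, so each entry of
# the returned list is a readable transcript for one n-best hypothesis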


def optimize_models(args, models):
    """Optimize ensemble for generation"""
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()

        model.to(dev)


def calc_mean_invstddev(feature):
    if len(feature.shape) != 2:
        raise ValueError("We expect the input feature to be 2-D tensor")
    mean = np.mean(feature, axis=0)
    var = np.var(feature, axis=0)
    # avoid division by ~zero: check each feature dimension's variance
    if (var < sys.float_info.epsilon).any():
        return mean, 1.0 / (np.sqrt(var) + sys.float_info.epsilon)
    return mean, 1.0 / np.sqrt(var)


def calcMN(features):
    mean, invstddev = calc_mean_invstddev(features)
    res = (features - mean) * invstddev
    return res
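
# calcMN applies per-utterance CMVN (cepstral mean and variance normalization):
# every filterbank dimension is shifted to zero mean and scaled to unit
# variance over the frames of the utterance, so for a (frames, 80) array
# `feats`, calcMN(feats) has per-column mean ~0 and variance ~1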


def transcribe(waveform, args, task, generator, models, sp, tgt_dict):
    r"""
    CUDA_VISIBLE_DEVICES=0 python infer_asr.py /Users/jamarshon/Documents/downloads/ \
        --task speech_recognition --max-tokens 10000000 --nbest 1 --path \
        /Users/jamarshon/Downloads/checkpoint_avg_60_80.pt --beam 20
    """
    num_features = 80
    output = torchaudio.compliance.kaldi.fbank(waveform, num_mel_bins=num_features)
    output_cmvn = calcMN(output.cpu().detach().numpy())

    # size (m, n)
    source = torch.tensor(output_cmvn)
    source = source.to(dev)
    frames_lengths = torch.LongTensor([source.size(0)])

    # size (1, m, n). In general, if source is (x, m, n), then hypos is (x, ...)
    source.unsqueeze_(0)
    sample = {'net_input': {'src_tokens': source, 'src_lengths': frames_lengths}}

    hypos = task.inference_step(generator, models, sample)

    assert len(hypos) == 1
    transcription = []
    print(hypos)
    for i in range(len(hypos)):
        # Process top predictions
        hyp_words = process_predictions(args, hypos[i], sp, tgt_dict)
        transcription.append(hyp_words)

    print('transcription:', transcription)
    return transcription
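
# transcribe() expects a mono waveform tensor at 16 kHz (main() below
# downmixes and resamples before calling it) and returns a nested list of
# n-best word sequences for the single input utterance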


def main(args):
    check_args(args)
    import_user_module(args)

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 30000
    logger.info(args)

    # Load dataset splits
    task = tasks.setup_task(args)

    # Set dictionary
    tgt_dict = task.target_dictionary

    if args.ctc or args.rnnt:
        tgt_dict.add_symbol("<ctc_blank>")
        if args.ctc:
            logger.info("| decoding a ctc model")
        if args.rnnt:
            logger.info("| decoding a rnnt model")

    # Load ensemble
    logger.info("| loading model(s) from {}".format(args.path))
    models, _model_args = utils.load_ensemble_for_inference(
        args.path.split(":"),
        task,
        model_arg_overrides=eval(args.model_overrides),  # noqa
    )
    optimize_models(args, models)

    # Initialize generator
    generator = task.build_generator(args)

    sp = spm.SentencePieceProcessor()
    sp.Load(os.path.join(args.data, 'spm.model'))

    # TODO: replace this hard-coded sample path
    path = '/home/aakashns/speech_transcribe/deepspeech.pytorch/data/an4_dataset/train/an4/wav/cen8-mwhw-b.wav'
    if not os.path.exists(path):
        raise FileNotFoundError("Audio file not found: {}".format(path))
    waveform, sample_rate = torchaudio.load_wav(path)
    waveform = waveform.mean(0, True)
    waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)

    import time
    print(sample_rate, waveform.shape)
    start = time.time()
    transcribe(waveform, args, task, generator, models, sp, tgt_dict)
    end = time.time()
    print(end - start)


def cli_main():
    parser = options.get_generation_parser()
    parser = add_asr_eval_argument(parser)
    args = options.parse_args_and_arch(parser)
    main(args)


if __name__ == "__main__":
    cli_main()
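
For reference, a plausible invocation of this script (a sketch, assuming the models were saved into ./data as in the README and that fairseq is cloned alongside this repo):

    python infer_file.py ./data --task speech_recognition --max-tokens 10000000 --nbest 1 --path ./data/model.pt --beam 20 --user-dir ../fairseq/examples/speech_recognition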
