Fix asr warnings #10469

Open · wants to merge 20 commits into base: main
Changes from 6 commits
@@ -127,7 +127,7 @@ def perform_streaming(
# would pass the whole audio at once through the model like offline mode in order to compare the results with the streaming mode
# the output of the model in the offline and streaming mode should be exactly the same
with torch.inference_mode():
with autocast():
with autocast:
processed_signal, processed_signal_length = streaming_buffer.get_all_audios()
with torch.no_grad():
(
@@ -156,7 +156,7 @@ def perform_streaming(
pred_out_stream = None
for step_num, (chunk_audio, chunk_lengths) in enumerate(streaming_buffer_iter):
with torch.inference_mode():
with autocast():
with autocast:
# keep_all_outputs needs to be True for the last step of streaming when model is trained with att_context_style=regular
# otherwise the last outputs would get dropped

@@ -313,19 +313,7 @@ def main():
raise ValueError("Model does not support multiple lookaheads.")

global autocast
if (
args.use_amp
and torch.cuda.is_available()
and hasattr(torch.cuda, 'amp')
and hasattr(torch.cuda.amp, 'autocast')
):
logging.info("AMP enabled!\n")
autocast = torch.cuda.amp.autocast
else:

@contextlib.contextmanager
def autocast():
yield
autocast = torch.amp.autocast(asr_model.device.type, enabled=args.use_amp)

# configure the decoding config
decoding_cfg = asr_model.cfg.decoding
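For reference, the pattern adopted above creates one `torch.amp.autocast` instance up front and reuses it as a plain context manager (`with autocast:`), which is why the conditional `torch.cuda.amp.autocast` setup and the no-op `contextlib` fallback can be dropped. A minimal standalone sketch, assuming a hypothetical `use_amp` flag in place of `args.use_amp` (not code from this PR):

```python
import torch

device_type = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = True  # stands in for args.use_amp

# One reusable context-manager instance; enabled=False makes it a no-op,
# so no contextlib fallback is needed.
autocast = torch.amp.autocast(device_type, enabled=use_amp)

layer = torch.nn.Linear(80, 128).to(device_type)
x = torch.randn(8, 80, device=device_type)

with torch.inference_mode():
    with autocast:  # no call parentheses: it is already an instance
        y = layer(x)

print(y.dtype)  # reduced precision when AMP is enabled, float32 otherwise
```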
@@ -170,16 +170,6 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
# Disable config overwriting
OmegaConf.set_struct(model_cfg.preprocessor, True)

# setup AMP (optional)
if cfg.amp and torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and hasattr(torch.cuda.amp, 'autocast'):
logging.info("AMP enabled!\n")
autocast = torch.cuda.amp.autocast
else:

@contextlib.contextmanager
def autocast(*args, **kwargs):
yield

# Compute output filename
cfg = compute_output_filename(cfg, model_name)

@@ -208,7 +198,7 @@ def autocast(*args, **kwargs):

amp_dtype = torch.float16 if cfg.amp_dtype == "float16" else torch.bfloat16

with autocast(dtype=amp_dtype):
with torch.amp.autocast(asr_model.device.type, enabled=cfg.amp, dtype=amp_dtype):
with torch.no_grad():
hyps = get_buffered_pred_feat_multitaskAED(
frame_asr,
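The hunk above also threads an explicit `dtype` into `torch.amp.autocast`, selected from the config string. A rough standalone sketch of that dtype handling, with hypothetical variables standing in for `cfg.amp` and `cfg.amp_dtype`:

```python
import torch

use_amp = True               # stands in for cfg.amp
amp_dtype_name = "bfloat16"  # stands in for cfg.amp_dtype
device_type = "cuda" if torch.cuda.is_available() else "cpu"

# Same selection logic as in the script: anything other than "float16" means bfloat16.
amp_dtype = torch.float16 if amp_dtype_name == "float16" else torch.bfloat16

encoder = torch.nn.Conv1d(80, 256, kernel_size=3).to(device_type)
features = torch.randn(4, 80, 100, device=device_type)

with torch.amp.autocast(device_type, enabled=use_amp, dtype=amp_dtype):
    with torch.no_grad():
        encoded = encoder(features)
```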
@@ -88,7 +88,9 @@ class TranscriptionConfig:
# Chunked configs
chunk_len_in_secs: float = 1.6 # Chunk length in seconds
total_buffer_in_secs: float = 4.0 # Length of buffer (chunk + left and right padding) in seconds
model_stride: int = 8 # Model downsampling factor, 8 for Citrinet and FastConformer models and 4 for Conformer models.
model_stride: int = (
8 # Model downsampling factor, 8 for Citrinet and FastConformer models and 4 for Conformer models.
)

# Decoding strategy for CTC models
decoding: CTCDecodingConfig = CTCDecodingConfig()
@@ -163,16 +165,6 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
# Disable config overwriting
OmegaConf.set_struct(model_cfg.preprocessor, True)

# setup AMP (optional)
if cfg.amp and torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and hasattr(torch.cuda.amp, 'autocast'):
logging.info("AMP enabled!\n")
autocast = torch.cuda.amp.autocast
else:

@contextlib.contextmanager
def autocast():
yield

# Compute output filename
cfg = compute_output_filename(cfg, model_name)

@@ -214,20 +206,24 @@ def autocast():
logging.info(f"tokens_per_chunk is {tokens_per_chunk}, mid_delay is {mid_delay}")

frame_asr = FrameBatchASR(
asr_model=asr_model, frame_len=chunk_len, total_buffer=cfg.total_buffer_in_secs, batch_size=cfg.batch_size,
asr_model=asr_model,
frame_len=chunk_len,
total_buffer=cfg.total_buffer_in_secs,
batch_size=cfg.batch_size,
)

hyps = get_buffered_pred_feat(
frame_asr,
chunk_len,
tokens_per_chunk,
mid_delay,
model_cfg.preprocessor,
model_stride_in_secs,
asr_model.device,
manifest,
filepaths,
)
with torch.amp.autocast(asr_model.device.type, enabled=cfg.amp):
hyps = get_buffered_pred_feat(
frame_asr,
chunk_len,
tokens_per_chunk,
mid_delay,
model_cfg.preprocessor,
model_stride_in_secs,
asr_model.device,
manifest,
filepaths,
)
output_filename, pred_text_attr_name = write_transcription(
hyps, cfg, model_name, filepaths=filepaths, compute_langs=False, compute_timestamps=False
)
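A detail worth noting about the change above: wrapping the whole `get_buffered_pred_feat(...)` call in the autocast context works because autocast applies to whatever operations run while the context is active, including ops inside functions called from the block. A small standalone illustration with a hypothetical helper (not NeMo code):

```python
import torch

def run_model(model: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
    # This matmul runs under autocast whenever the caller has the context active.
    return model(x)

device_type = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(32, 8).to(device_type)
x = torch.randn(4, 32, device=device_type)

with torch.amp.autocast(device_type, enabled=True):
    y = run_model(model, x)  # reduced-precision result even though the call is nested
```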
@@ -84,8 +84,6 @@
from nemo.core.config import hydra_runner
from nemo.utils import logging

can_gpu = torch.cuda.is_available()


@dataclass
class TranscriptionConfig:
@@ -112,7 +110,9 @@ class TranscriptionConfig:
# Chunked configs
chunk_len_in_secs: float = 1.6 # Chunk length in seconds
total_buffer_in_secs: float = 4.0 # Length of buffer (chunk + left and right padding) in seconds
model_stride: int = 8 # Model downsampling factor, 8 for Citrinet and FastConformer models and 4 for Conformer models.
model_stride: int = (
8 # Model downsampling factor, 8 for Citrinet and FastConformer models and 4 for Conformer models.
)

# Set `cuda` to int to define CUDA device. If 'None', will look for CUDA
# device anyway, and do inference on CPU only if CUDA device is not found.
@@ -274,6 +274,7 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
batch_size=cfg.batch_size,
manifest=manifest,
filepaths=filepaths,
accelerator=accelerator,
)

output_filename, pred_text_attr_name = write_transcription(
42 changes: 23 additions & 19 deletions examples/asr/asr_vad/speech_to_text_with_vad.py
@@ -63,6 +63,7 @@
from typing import Callable, Optional

import torch
import torch.amp
import yaml
from omegaconf import DictConfig, OmegaConf
from torch.profiler import ProfilerActivity, profile, record_function
@@ -84,14 +85,6 @@
from nemo.core.config import hydra_runner
from nemo.utils import logging

try:
from torch.cuda.amp import autocast
except ImportError:

@contextlib.contextmanager
def autocast(enabled=None):
yield


@dataclass
class InferenceConfig:
@@ -105,9 +98,9 @@ class InferenceConfig:
use_rttm: bool = True # whether to use RTTM
rttm_mode: str = "mask" # how to use RTTM files, choices=[`mask`, `drop`]
feat_mask_val: Optional[float] = None # value used to mask features based on RTTM, set None to use defaults
normalize: Optional[
str
] = "post_norm" # whether and where to normalize audio feature, choices=[None, `pre_norm`, `post_norm`]
normalize: Optional[str] = (
"post_norm" # whether and where to normalize audio feature, choices=[None, `pre_norm`, `post_norm`]
)
normalize_type: str = "per_feature" # how to determine mean and std used for normalization
normalize_audio_db: Optional[float] = None # set to normalize RMS DB of audio before extracting audio features

@@ -117,7 +110,9 @@ class InferenceConfig:
batch_size: int = 1 # batch size for ASR. Feature extraction and VAD only support single sample per batch.
num_workers: int = 8
sample_rate: int = 16000
frame_unit_time_secs: float = 0.01 # unit time per frame in seconds, equal to `window_stride` in ASR configs, typically 10ms.
frame_unit_time_secs: float = (
0.01 # unit time per frame in seconds, equal to `window_stride` in ASR configs, typically 10ms.
)
audio_type: str = "wav"

# Output settings, no need to change
@@ -263,7 +258,9 @@ def extract_audio_features(manifest_filepath: str, cfg: DictConfig, record_fn: C
'vad_stream': False,
'sample_rate': cfg.sample_rate,
'manifest_filepath': manifest_filepath,
'labels': ['infer',],
'labels': [
'infer',
],
'num_workers': cfg.num_workers,
'shuffle': False,
'normalize_audio_db': cfg.normalize_audio_db,
@@ -274,10 +271,11 @@ def extract_audio_features(manifest_filepath: str, cfg: DictConfig, record_fn: C
with record_fn("feat_extract_loop"):
for i, test_batch in enumerate(tqdm(vad_model.test_dataloader(), total=len(vad_model.test_dataloader()))):
test_batch = [x.to(vad_model.device) for x in test_batch]
with autocast():
with torch.amp.autocast(vad_model.device.type):
with record_fn("feat_extract_infer"):
processed_signal, processed_signal_length = vad_model.preprocessor(
input_signal=test_batch[0], length=test_batch[1],
input_signal=test_batch[0],
length=test_batch[1],
)
with record_fn("feat_extract_other"):
processed_signal = processed_signal.squeeze(0)[:, :processed_signal_length]
@@ -317,7 +315,9 @@ def run_vad_inference(manifest_filepath: str, cfg: DictConfig, record_fn: Callab
test_data_config = {
'vad_stream': True,
'manifest_filepath': manifest_filepath,
'labels': ['infer',],
'labels': [
'infer',
],
'num_workers': cfg.num_workers,
'shuffle': False,
'window_length_in_sec': vad_cfg.vad.parameters.window_length_in_sec,
@@ -438,7 +438,7 @@ def generate_vad_frame_pred(
with record_fn("vad_infer_loop"):
for i, test_batch in enumerate(tqdm(vad_model.test_dataloader(), total=len(vad_model.test_dataloader()))):
test_batch = [x.to(vad_model.device) for x in test_batch]
with autocast():
with torch.amp.autocast(vad_model.device.type):
with record_fn("vad_infer_model"):
if use_feat:
log_probs = vad_model(processed_signal=test_batch[0], processed_signal_length=test_batch[1])
@@ -572,7 +572,7 @@ def run_asr_inference(manifest_filepath, cfg, record_fn) -> str:
hypotheses = []
all_hypotheses = []
t0 = time.time()
with autocast():
with torch.amp.autocast(asr_model.device.type):
with torch.no_grad():
with record_fn("asr_infer_loop"):
for test_batch in tqdm(dataloader, desc="Transcribing"):
@@ -585,7 +585,11 @@ def run_asr_inference(manifest_filepath, cfg, record_fn) -> str:
with record_fn("asr_infer_other"):
logits, logits_len = outputs[0], outputs[1]

current_hypotheses, all_hyp = decode_function(logits, logits_len, return_hypotheses=False,)
current_hypotheses, all_hyp = decode_function(
logits,
logits_len,
return_hypotheses=False,
)
if isinstance(current_hypotheses, tuple) and len(current_hypotheses) == 2:
current_hypotheses = current_hypotheses[0] # handle RNNT output

17 changes: 6 additions & 11 deletions examples/asr/experimental/sclite/speech_to_text_sclite.py
@@ -42,15 +42,6 @@
from nemo.collections.asr.parts.utils.manifest_utils import read_manifest
from nemo.utils import logging

try:
from torch.cuda.amp import autocast
except ImportError:
from contextlib import contextmanager

@contextmanager
def autocast(enabled=None):
yield


def score_with_sctk(sctk_dir, ref_fname, hyp_fname, out_dir, glm=""):
sclite_path = os.path.join(sctk_dir, "bin", "sclite")
@@ -91,7 +82,11 @@ def get_utt_info(manifest_path):
def main():
parser = ArgumentParser()
parser.add_argument(
"--asr_model", type=str, default="QuartzNet15x5Base-En", required=False, help="Pass: 'QuartzNet15x5Base-En'",
"--asr_model",
type=str,
default="QuartzNet15x5Base-En",
required=False,
help="Pass: 'QuartzNet15x5Base-En'",
)
parser.add_argument("--dataset", type=str, required=True, help="path to evaluation data")
parser.add_argument("--batch_size", type=int, default=4)
@@ -123,7 +118,7 @@ def main():
references = [data['text'] for data in manifest_data]
audio_filepaths = [data['audio_filepath'] for data in manifest_data]

with autocast():
with torch.amp.autocast(asr_model.device.type):
hypotheses = asr_model.transcribe(audio_filepaths, batch_size=args.batch_size)

# if transcriptions form a tuple (from RNNT), extract just "best" hypothesis
21 changes: 6 additions & 15 deletions examples/asr/quantization/speech_to_text_calibrate.py
@@ -35,23 +35,17 @@
"https://github.com/NVIDIA/TensorRT/tree/master/tools/pytorch-quantization."
)

try:
from torch.cuda.amp import autocast
except ImportError:
from contextlib import contextmanager

@contextmanager
def autocast(enabled=None):
yield


can_gpu = torch.cuda.is_available()


def main():
parser = ArgumentParser()
parser.add_argument(
"--asr_model", type=str, default="QuartzNet15x5Base-En", required=True, help="Pass: 'QuartzNet15x5Base-En'",
"--asr_model",
type=str,
default="QuartzNet15x5Base-En",
required=True,
help="Pass: 'QuartzNet15x5Base-En'",
)
parser.add_argument("--dataset", type=str, required=True, help="path to evaluation data")
parser.add_argument("--batch_size", type=int, default=256)
@@ -118,11 +112,8 @@ def main():
for i, test_batch in enumerate(asr_model.test_dataloader()):
if can_gpu:
test_batch = [x.cuda() for x in test_batch]
if args.amp:
with autocast():
with torch.amp.autocast(asr_model.device.type, enabled=args.amp):
_ = asr_model(input_signal=test_batch[0], input_signal_length=test_batch[1])
else:
_ = asr_model(input_signal=test_batch[0], input_signal_length=test_batch[1])
if i >= args.num_calib_batch:
break

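The calibration loop above collapses the old `if args.amp: with autocast(): ...` / `else: ...` duplication into one block gated by `enabled=args.amp`. A rough sketch of the same idea, with hypothetical names:

```python
import torch

def calibrate_batch(model: torch.nn.Module, batch: torch.Tensor, use_amp: bool) -> None:
    # One code path for both modes: enabled=False makes autocast a no-op.
    with torch.amp.autocast(batch.device.type, enabled=use_amp):
        model(batch)

model = torch.nn.Linear(10, 2)
calibrate_batch(model, torch.randn(4, 10), use_amp=False)
```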
18 changes: 6 additions & 12 deletions examples/asr/quantization/speech_to_text_quant_infer.py
@@ -38,23 +38,17 @@
)


try:
from torch.cuda.amp import autocast
except ImportError:
from contextlib import contextmanager

@contextmanager
def autocast(enabled=None):
yield


can_gpu = torch.cuda.is_available()


def main():
parser = ArgumentParser()
parser.add_argument(
"--asr_model", type=str, default="QuartzNet15x5Base-En", required=True, help="Pass: 'QuartzNet15x5Base-En'",
"--asr_model",
type=str,
default="QuartzNet15x5Base-En",
required=True,
help="Pass: 'QuartzNet15x5Base-En'",
)
parser.add_argument("--dataset", type=str, required=True, help="path to evaluation data")
parser.add_argument("--wer_target", type=float, default=None, help="used by test")
@@ -199,7 +193,7 @@ def evaluate(asr_model, labels_map, wer):
for test_batch in asr_model.test_dataloader():
if can_gpu:
test_batch = [x.cuda() for x in test_batch]
with autocast():
with torch.amp.autocast(asr_model.device.type):
log_probs, encoded_len, greedy_predictions = asr_model(
input_signal=test_batch[0], input_signal_length=test_batch[1]
)