Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix asr warnings #10469

Merged
merged 27 commits into from
Sep 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
4918ef9
check
Sep 10, 2024
a69b723
1
Sep 10, 2024
233dcc3
added to examples/asr
Sep 11, 2024
e8cea2e
deprecates cuda.amp.autocast to replace with amp.autocast(**Args)
Sep 11, 2024
7c11cac
Apply isort and black reformatting
nithinraok Sep 11, 2024
e197d3b
check
Sep 10, 2024
466286a
1
Sep 10, 2024
f081911
added to examples/asr
Sep 11, 2024
0122720
deprecates cuda.amp.autocast to replace with amp.autocast(**Args)
Sep 11, 2024
d8a78a6
Apply isort and black reformatting
nithinraok Sep 11, 2024
e0c711f
Merge branch 'main' into fix_asr_warnings
nithinraok Sep 16, 2024
ef136be
Merge branch 'fix_asr_warnings' of https://github.com/NVIDIA/NeMo int…
Sep 16, 2024
e7ae53c
tested on V100
Sep 16, 2024
7cf2bcd
Apply isort and black reformatting
nithinraok Sep 16, 2024
19b654f
replace cuda for jit scripted modules
Sep 16, 2024
9e8b82e
Merge branch 'main' into fix_asr_warnings
nithinraok Sep 16, 2024
973b729
Merge branch 'fix_asr_warnings' of https://github.com/NVIDIA/NeMo int…
Sep 17, 2024
87e6981
device type fix
Sep 17, 2024
9fe290c
Merge branch 'main' into fix_asr_warnings
nithinraok Sep 17, 2024
a9cb5b2
Merge branch 'main' into fix_asr_warnings
nithinraok Sep 20, 2024
8053565
Merge branch 'main' into fix_asr_warnings
nithinraok Sep 22, 2024
dec24c7
revert diar for CI
Sep 24, 2024
4230877
Merge branch 'main' into fix_asr_warnings
nithinraok Sep 24, 2024
db554df
Merge branch 'main' into fix_asr_warnings
nithinraok Sep 24, 2024
7b453a5
Merge branch 'main' into fix_asr_warnings
nithinraok Sep 25, 2024
f6d6370
Merge branch 'main' into fix_asr_warnings
nithinraok Sep 25, 2024
95c4b7b
Merge branch 'main' into fix_asr_warnings
nithinraok Sep 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def perform_streaming(
# would pass the whole audio at once through the model like offline mode in order to compare the results with the stremaing mode
# the output of the model in the offline and streaming mode should be exactly the same
with torch.inference_mode():
with autocast():
with autocast:
processed_signal, processed_signal_length = streaming_buffer.get_all_audios()
with torch.no_grad():
(
Expand Down Expand Up @@ -156,7 +156,7 @@ def perform_streaming(
pred_out_stream = None
for step_num, (chunk_audio, chunk_lengths) in enumerate(streaming_buffer_iter):
with torch.inference_mode():
with autocast():
with autocast:
# keep_all_outputs needs to be True for the last step of streaming when model is trained with att_context_style=regular
# otherwise the last outputs would get dropped

Expand Down Expand Up @@ -313,19 +313,7 @@ def main():
raise ValueError("Model does not support multiple lookaheads.")

global autocast
if (
args.use_amp
and torch.cuda.is_available()
and hasattr(torch.cuda, 'amp')
and hasattr(torch.cuda.amp, 'autocast')
):
logging.info("AMP enabled!\n")
autocast = torch.cuda.amp.autocast
else:

@contextlib.contextmanager
def autocast():
yield
autocast = torch.amp.autocast(asr_model.device.type, enabled=args.use_amp)

# configure the decoding config
decoding_cfg = asr_model.cfg.decoding
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,16 +170,6 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
# Disable config overwriting
OmegaConf.set_struct(model_cfg.preprocessor, True)

# setup AMP (optional)
if cfg.amp and torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and hasattr(torch.cuda.amp, 'autocast'):
logging.info("AMP enabled!\n")
autocast = torch.cuda.amp.autocast
else:

@contextlib.contextmanager
def autocast(*args, **kwargs):
yield

# Compute output filename
cfg = compute_output_filename(cfg, model_name)

Expand Down Expand Up @@ -208,7 +198,7 @@ def autocast(*args, **kwargs):

amp_dtype = torch.float16 if cfg.amp_dtype == "float16" else torch.bfloat16

with autocast(dtype=amp_dtype):
with torch.amp.autocast(asr_model.device.type, enabled=cfg.amp, dtype=amp_dtype):
with torch.no_grad():
hyps = get_buffered_pred_feat_multitaskAED(
frame_asr,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,9 @@ class TranscriptionConfig:
# Chunked configs
chunk_len_in_secs: float = 1.6 # Chunk length in seconds
total_buffer_in_secs: float = 4.0 # Length of buffer (chunk + left and right padding) in seconds
model_stride: int = 8 # Model downsampling factor, 8 for Citrinet and FasConformer models and 4 for Conformer models.
model_stride: int = (
8 # Model downsampling factor, 8 for Citrinet and FasConformer models and 4 for Conformer models.
)

# Decoding strategy for CTC models
decoding: CTCDecodingConfig = CTCDecodingConfig()
Expand Down Expand Up @@ -163,16 +165,6 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
# Disable config overwriting
OmegaConf.set_struct(model_cfg.preprocessor, True)

# setup AMP (optional)
if cfg.amp and torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and hasattr(torch.cuda.amp, 'autocast'):
logging.info("AMP enabled!\n")
autocast = torch.cuda.amp.autocast
else:

@contextlib.contextmanager
def autocast():
yield

# Compute output filename
cfg = compute_output_filename(cfg, model_name)

Expand Down Expand Up @@ -214,20 +206,24 @@ def autocast():
logging.info(f"tokens_per_chunk is {tokens_per_chunk}, mid_delay is {mid_delay}")

frame_asr = FrameBatchASR(
asr_model=asr_model, frame_len=chunk_len, total_buffer=cfg.total_buffer_in_secs, batch_size=cfg.batch_size,
asr_model=asr_model,
frame_len=chunk_len,
total_buffer=cfg.total_buffer_in_secs,
batch_size=cfg.batch_size,
)

hyps = get_buffered_pred_feat(
frame_asr,
chunk_len,
tokens_per_chunk,
mid_delay,
model_cfg.preprocessor,
model_stride_in_secs,
asr_model.device,
manifest,
filepaths,
)
with torch.amp.autocast(asr_model.device.type, enabled=cfg.amp):
hyps = get_buffered_pred_feat(
frame_asr,
chunk_len,
tokens_per_chunk,
mid_delay,
model_cfg.preprocessor,
model_stride_in_secs,
asr_model.device,
manifest,
filepaths,
)
output_filename, pred_text_attr_name = write_transcription(
hyps, cfg, model_name, filepaths=filepaths, compute_langs=False, compute_timestamps=False
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,6 @@
from nemo.core.config import hydra_runner
from nemo.utils import logging

can_gpu = torch.cuda.is_available()


@dataclass
class TranscriptionConfig:
Expand All @@ -112,7 +110,9 @@ class TranscriptionConfig:
# Chunked configs
chunk_len_in_secs: float = 1.6 # Chunk length in seconds
total_buffer_in_secs: float = 4.0 # Length of buffer (chunk + left and right padding) in seconds
model_stride: int = 8 # Model downsampling factor, 8 for Citrinet and FastConformer models and 4 for Conformer models.
model_stride: int = (
8 # Model downsampling factor, 8 for Citrinet and FastConformer models and 4 for Conformer models.
)

# Set `cuda` to int to define CUDA device. If 'None', will look for CUDA
# device anyway, and do inference on CPU only if CUDA device is not found.
Expand Down Expand Up @@ -274,6 +274,7 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
batch_size=cfg.batch_size,
manifest=manifest,
filepaths=filepaths,
accelerator=accelerator,
)

output_filename, pred_text_attr_name = write_transcription(
Expand Down
42 changes: 23 additions & 19 deletions examples/asr/asr_vad/speech_to_text_with_vad.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
from typing import Callable, Optional

import torch
import torch.amp
import yaml
from omegaconf import DictConfig, OmegaConf
from torch.profiler import ProfilerActivity, profile, record_function
Expand All @@ -84,14 +85,6 @@
from nemo.core.config import hydra_runner
from nemo.utils import logging

try:
from torch.cuda.amp import autocast
except ImportError:

@contextlib.contextmanager
def autocast(enabled=None):
yield


@dataclass
class InferenceConfig:
Expand All @@ -105,9 +98,9 @@ class InferenceConfig:
use_rttm: bool = True # whether to use RTTM
rttm_mode: str = "mask" # how to use RTTM files, choices=[`mask`, `drop`]
feat_mask_val: Optional[float] = None # value used to mask features based on RTTM, set None to use defaults
normalize: Optional[
str
] = "post_norm" # whether and where to normalize audio feature, choices=[None, `pre_norm`, `post_norm`]
normalize: Optional[str] = (
"post_norm" # whether and where to normalize audio feature, choices=[None, `pre_norm`, `post_norm`]
)
normalize_type: str = "per_feature" # how to determine mean and std used for normalization
normalize_audio_db: Optional[float] = None # set to normalize RMS DB of audio before extracting audio features

Expand All @@ -117,7 +110,9 @@ class InferenceConfig:
batch_size: int = 1 # batch size for ASR. Feature extraction and VAD only support single sample per batch.
num_workers: int = 8
sample_rate: int = 16000
frame_unit_time_secs: float = 0.01 # unit time per frame in seconds, equal to `window_stride` in ASR configs, typically 10ms.
frame_unit_time_secs: float = (
0.01 # unit time per frame in seconds, equal to `window_stride` in ASR configs, typically 10ms.
)
audio_type: str = "wav"

# Output settings, no need to change
Expand Down Expand Up @@ -263,7 +258,9 @@ def extract_audio_features(manifest_filepath: str, cfg: DictConfig, record_fn: C
'vad_stream': False,
'sample_rate': cfg.sample_rate,
'manifest_filepath': manifest_filepath,
'labels': ['infer',],
'labels': [
'infer',
],
'num_workers': cfg.num_workers,
'shuffle': False,
'normalize_audio_db': cfg.normalize_audio_db,
Expand All @@ -274,10 +271,11 @@ def extract_audio_features(manifest_filepath: str, cfg: DictConfig, record_fn: C
with record_fn("feat_extract_loop"):
for i, test_batch in enumerate(tqdm(vad_model.test_dataloader(), total=len(vad_model.test_dataloader()))):
test_batch = [x.to(vad_model.device) for x in test_batch]
with autocast():
with torch.amp.autocast(vad_model.device.type):
with record_fn("feat_extract_infer"):
processed_signal, processed_signal_length = vad_model.preprocessor(
input_signal=test_batch[0], length=test_batch[1],
input_signal=test_batch[0],
length=test_batch[1],
)
with record_fn("feat_extract_other"):
processed_signal = processed_signal.squeeze(0)[:, :processed_signal_length]
Expand Down Expand Up @@ -317,7 +315,9 @@ def run_vad_inference(manifest_filepath: str, cfg: DictConfig, record_fn: Callab
test_data_config = {
'vad_stream': True,
'manifest_filepath': manifest_filepath,
'labels': ['infer',],
'labels': [
'infer',
],
'num_workers': cfg.num_workers,
'shuffle': False,
'window_length_in_sec': vad_cfg.vad.parameters.window_length_in_sec,
Expand Down Expand Up @@ -438,7 +438,7 @@ def generate_vad_frame_pred(
with record_fn("vad_infer_loop"):
for i, test_batch in enumerate(tqdm(vad_model.test_dataloader(), total=len(vad_model.test_dataloader()))):
test_batch = [x.to(vad_model.device) for x in test_batch]
with autocast():
with torch.amp.autocast(vad_model.device.type):
with record_fn("vad_infer_model"):
if use_feat:
log_probs = vad_model(processed_signal=test_batch[0], processed_signal_length=test_batch[1])
Expand Down Expand Up @@ -572,7 +572,7 @@ def run_asr_inference(manifest_filepath, cfg, record_fn) -> str:
hypotheses = []
all_hypotheses = []
t0 = time.time()
with autocast():
with torch.amp.autocast(asr_model.device.type):
with torch.no_grad():
with record_fn("asr_infer_loop"):
for test_batch in tqdm(dataloader, desc="Transcribing"):
Expand All @@ -585,7 +585,11 @@ def run_asr_inference(manifest_filepath, cfg, record_fn) -> str:
with record_fn("asr_infer_other"):
logits, logits_len = outputs[0], outputs[1]

current_hypotheses, all_hyp = decode_function(logits, logits_len, return_hypotheses=False,)
current_hypotheses, all_hyp = decode_function(
logits,
logits_len,
return_hypotheses=False,
)
if isinstance(current_hypotheses, tuple) and len(current_hypotheses) == 2:
current_hypotheses = current_hypotheses[0] # handle RNNT output

Expand Down
17 changes: 6 additions & 11 deletions examples/asr/experimental/sclite/speech_to_text_sclite.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,6 @@
from nemo.collections.asr.parts.utils.manifest_utils import read_manifest
from nemo.utils import logging

try:
from torch.cuda.amp import autocast
except ImportError:
from contextlib import contextmanager

@contextmanager
def autocast(enabled=None):
yield


def score_with_sctk(sctk_dir, ref_fname, hyp_fname, out_dir, glm=""):
sclite_path = os.path.join(sctk_dir, "bin", "sclite")
Expand Down Expand Up @@ -91,7 +82,11 @@ def get_utt_info(manifest_path):
def main():
parser = ArgumentParser()
parser.add_argument(
"--asr_model", type=str, default="QuartzNet15x5Base-En", required=False, help="Pass: 'QuartzNet15x5Base-En'",
"--asr_model",
type=str,
default="QuartzNet15x5Base-En",
required=False,
help="Pass: 'QuartzNet15x5Base-En'",
)
parser.add_argument("--dataset", type=str, required=True, help="path to evaluation data")
parser.add_argument("--batch_size", type=int, default=4)
Expand Down Expand Up @@ -123,7 +118,7 @@ def main():
references = [data['text'] for data in manifest_data]
audio_filepaths = [data['audio_filepath'] for data in manifest_data]

with autocast():
with torch.amp.autocast(asr_model.device.type):
hypotheses = asr_model.transcribe(audio_filepaths, batch_size=args.batch_size)

# if transcriptions form a tuple (from RNNT), extract just "best" hypothesis
Expand Down
21 changes: 6 additions & 15 deletions examples/asr/quantization/speech_to_text_calibrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,23 +35,17 @@
"https://github.com/NVIDIA/TensorRT/tree/master/tools/pytorch-quantization."
)

try:
from torch.cuda.amp import autocast
except ImportError:
from contextlib import contextmanager

@contextmanager
def autocast(enabled=None):
yield


can_gpu = torch.cuda.is_available()


def main():
parser = ArgumentParser()
parser.add_argument(
"--asr_model", type=str, default="QuartzNet15x5Base-En", required=True, help="Pass: 'QuartzNet15x5Base-En'",
"--asr_model",
type=str,
default="QuartzNet15x5Base-En",
required=True,
help="Pass: 'QuartzNet15x5Base-En'",
)
parser.add_argument("--dataset", type=str, required=True, help="path to evaluation data")
parser.add_argument("--batch_size", type=int, default=256)
Expand Down Expand Up @@ -118,11 +112,8 @@ def main():
for i, test_batch in enumerate(asr_model.test_dataloader()):
if can_gpu:
test_batch = [x.cuda() for x in test_batch]
if args.amp:
with autocast():
with torch.amp.autocast(asr_model.device.type, enabled=args.amp):
_ = asr_model(input_signal=test_batch[0], input_signal_length=test_batch[1])
else:
_ = asr_model(input_signal=test_batch[0], input_signal_length=test_batch[1])
if i >= args.num_calib_batch:
break

Expand Down
18 changes: 6 additions & 12 deletions examples/asr/quantization/speech_to_text_quant_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,23 +38,17 @@
)


try:
from torch.cuda.amp import autocast
except ImportError:
from contextlib import contextmanager

@contextmanager
def autocast(enabled=None):
yield


can_gpu = torch.cuda.is_available()


def main():
parser = ArgumentParser()
parser.add_argument(
"--asr_model", type=str, default="QuartzNet15x5Base-En", required=True, help="Pass: 'QuartzNet15x5Base-En'",
"--asr_model",
type=str,
default="QuartzNet15x5Base-En",
required=True,
help="Pass: 'QuartzNet15x5Base-En'",
)
parser.add_argument("--dataset", type=str, required=True, help="path to evaluation data")
parser.add_argument("--wer_target", type=float, default=None, help="used by test")
Expand Down Expand Up @@ -199,7 +193,7 @@ def evaluate(asr_model, labels_map, wer):
for test_batch in asr_model.test_dataloader():
if can_gpu:
test_batch = [x.cuda() for x in test_batch]
with autocast():
with torch.amp.autocast(asr_model.device.type):
log_probs, encoded_len, greedy_predictions = asr_model(
input_signal=test_batch[0], input_signal_length=test_batch[1]
)
Expand Down
Loading
Loading