diff --git a/examples/asr/asr_cache_aware_streaming/speech_to_text_cache_aware_streaming_infer.py b/examples/asr/asr_cache_aware_streaming/speech_to_text_cache_aware_streaming_infer.py index 476d8ff70786..55852ee3ba8f 100644 --- a/examples/asr/asr_cache_aware_streaming/speech_to_text_cache_aware_streaming_infer.py +++ b/examples/asr/asr_cache_aware_streaming/speech_to_text_cache_aware_streaming_infer.py @@ -127,7 +127,7 @@ def perform_streaming( # would pass the whole audio at once through the model like offline mode in order to compare the results with the stremaing mode # the output of the model in the offline and streaming mode should be exactly the same with torch.inference_mode(): - with autocast(): + with autocast: processed_signal, processed_signal_length = streaming_buffer.get_all_audios() with torch.no_grad(): ( @@ -156,7 +156,7 @@ def perform_streaming( pred_out_stream = None for step_num, (chunk_audio, chunk_lengths) in enumerate(streaming_buffer_iter): with torch.inference_mode(): - with autocast(): + with autocast: # keep_all_outputs needs to be True for the last step of streaming when model is trained with att_context_style=regular # otherwise the last outputs would get dropped @@ -313,19 +313,7 @@ def main(): raise ValueError("Model does not support multiple lookaheads.") global autocast - if ( - args.use_amp - and torch.cuda.is_available() - and hasattr(torch.cuda, 'amp') - and hasattr(torch.cuda.amp, 'autocast') - ): - logging.info("AMP enabled!\n") - autocast = torch.cuda.amp.autocast - else: - - @contextlib.contextmanager - def autocast(): - yield + autocast = torch.amp.autocast(asr_model.device.type, enabled=args.use_amp) # configure the decoding config decoding_cfg = asr_model.cfg.decoding diff --git a/examples/asr/asr_chunked_inference/aed/speech_to_text_aed_chunked_infer.py b/examples/asr/asr_chunked_inference/aed/speech_to_text_aed_chunked_infer.py index 39b7547923cd..0195c1edd239 100644 --- a/examples/asr/asr_chunked_inference/aed/speech_to_text_aed_chunked_infer.py +++ b/examples/asr/asr_chunked_inference/aed/speech_to_text_aed_chunked_infer.py @@ -170,16 +170,6 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig: # Disable config overwriting OmegaConf.set_struct(model_cfg.preprocessor, True) - # setup AMP (optional) - if cfg.amp and torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and hasattr(torch.cuda.amp, 'autocast'): - logging.info("AMP enabled!\n") - autocast = torch.cuda.amp.autocast - else: - - @contextlib.contextmanager - def autocast(*args, **kwargs): - yield - # Compute output filename cfg = compute_output_filename(cfg, model_name) @@ -208,7 +198,7 @@ def autocast(*args, **kwargs): amp_dtype = torch.float16 if cfg.amp_dtype == "float16" else torch.bfloat16 - with autocast(dtype=amp_dtype): + with torch.amp.autocast(asr_model.device.type, enabled=cfg.amp, dtype=amp_dtype): with torch.no_grad(): hyps = get_buffered_pred_feat_multitaskAED( frame_asr, diff --git a/examples/asr/asr_chunked_inference/ctc/speech_to_text_buffered_infer_ctc.py b/examples/asr/asr_chunked_inference/ctc/speech_to_text_buffered_infer_ctc.py index 1d01e8e6a7a1..3feef6a027b8 100644 --- a/examples/asr/asr_chunked_inference/ctc/speech_to_text_buffered_infer_ctc.py +++ b/examples/asr/asr_chunked_inference/ctc/speech_to_text_buffered_infer_ctc.py @@ -88,7 +88,9 @@ class TranscriptionConfig: # Chunked configs chunk_len_in_secs: float = 1.6 # Chunk length in seconds total_buffer_in_secs: float = 4.0 # Length of buffer (chunk + left and right padding) in seconds - 
model_stride: int = 8 # Model downsampling factor, 8 for Citrinet and FasConformer models and 4 for Conformer models. + model_stride: int = ( + 8 # Model downsampling factor, 8 for Citrinet and FasConformer models and 4 for Conformer models. + ) # Decoding strategy for CTC models decoding: CTCDecodingConfig = CTCDecodingConfig() @@ -163,16 +165,6 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig: # Disable config overwriting OmegaConf.set_struct(model_cfg.preprocessor, True) - # setup AMP (optional) - if cfg.amp and torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and hasattr(torch.cuda.amp, 'autocast'): - logging.info("AMP enabled!\n") - autocast = torch.cuda.amp.autocast - else: - - @contextlib.contextmanager - def autocast(): - yield - # Compute output filename cfg = compute_output_filename(cfg, model_name) @@ -214,20 +206,24 @@ def autocast(): logging.info(f"tokens_per_chunk is {tokens_per_chunk}, mid_delay is {mid_delay}") frame_asr = FrameBatchASR( - asr_model=asr_model, frame_len=chunk_len, total_buffer=cfg.total_buffer_in_secs, batch_size=cfg.batch_size, + asr_model=asr_model, + frame_len=chunk_len, + total_buffer=cfg.total_buffer_in_secs, + batch_size=cfg.batch_size, ) - hyps = get_buffered_pred_feat( - frame_asr, - chunk_len, - tokens_per_chunk, - mid_delay, - model_cfg.preprocessor, - model_stride_in_secs, - asr_model.device, - manifest, - filepaths, - ) + with torch.amp.autocast(asr_model.device.type, enabled=cfg.amp): + hyps = get_buffered_pred_feat( + frame_asr, + chunk_len, + tokens_per_chunk, + mid_delay, + model_cfg.preprocessor, + model_stride_in_secs, + asr_model.device, + manifest, + filepaths, + ) output_filename, pred_text_attr_name = write_transcription( hyps, cfg, model_name, filepaths=filepaths, compute_langs=False, compute_timestamps=False ) diff --git a/examples/asr/asr_chunked_inference/rnnt/speech_to_text_buffered_infer_rnnt.py b/examples/asr/asr_chunked_inference/rnnt/speech_to_text_buffered_infer_rnnt.py index ea82796eab39..2014d8782bca 100644 --- a/examples/asr/asr_chunked_inference/rnnt/speech_to_text_buffered_infer_rnnt.py +++ b/examples/asr/asr_chunked_inference/rnnt/speech_to_text_buffered_infer_rnnt.py @@ -84,8 +84,6 @@ from nemo.core.config import hydra_runner from nemo.utils import logging -can_gpu = torch.cuda.is_available() - @dataclass class TranscriptionConfig: @@ -112,7 +110,9 @@ class TranscriptionConfig: # Chunked configs chunk_len_in_secs: float = 1.6 # Chunk length in seconds total_buffer_in_secs: float = 4.0 # Length of buffer (chunk + left and right padding) in seconds - model_stride: int = 8 # Model downsampling factor, 8 for Citrinet and FastConformer models and 4 for Conformer models. + model_stride: int = ( + 8 # Model downsampling factor, 8 for Citrinet and FastConformer models and 4 for Conformer models. + ) # Set `cuda` to int to define CUDA device. If 'None', will look for CUDA # device anyway, and do inference on CPU only if CUDA device is not found. 
@@ -274,6 +274,7 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig: batch_size=cfg.batch_size, manifest=manifest, filepaths=filepaths, + accelerator=accelerator, ) output_filename, pred_text_attr_name = write_transcription( diff --git a/examples/asr/asr_vad/speech_to_text_with_vad.py b/examples/asr/asr_vad/speech_to_text_with_vad.py index 391f299aa441..27ca7bc1f84c 100644 --- a/examples/asr/asr_vad/speech_to_text_with_vad.py +++ b/examples/asr/asr_vad/speech_to_text_with_vad.py @@ -63,6 +63,7 @@ from typing import Callable, Optional import torch +import torch.amp import yaml from omegaconf import DictConfig, OmegaConf from torch.profiler import ProfilerActivity, profile, record_function @@ -84,14 +85,6 @@ from nemo.core.config import hydra_runner from nemo.utils import logging -try: - from torch.cuda.amp import autocast -except ImportError: - - @contextlib.contextmanager - def autocast(enabled=None): - yield - @dataclass class InferenceConfig: @@ -105,9 +98,9 @@ class InferenceConfig: use_rttm: bool = True # whether to use RTTM rttm_mode: str = "mask" # how to use RTTM files, choices=[`mask`, `drop`] feat_mask_val: Optional[float] = None # value used to mask features based on RTTM, set None to use defaults - normalize: Optional[ - str - ] = "post_norm" # whether and where to normalize audio feature, choices=[None, `pre_norm`, `post_norm`] + normalize: Optional[str] = ( + "post_norm" # whether and where to normalize audio feature, choices=[None, `pre_norm`, `post_norm`] + ) normalize_type: str = "per_feature" # how to determine mean and std used for normalization normalize_audio_db: Optional[float] = None # set to normalize RMS DB of audio before extracting audio features @@ -117,7 +110,9 @@ class InferenceConfig: batch_size: int = 1 # batch size for ASR. Feature extraction and VAD only support single sample per batch. num_workers: int = 8 sample_rate: int = 16000 - frame_unit_time_secs: float = 0.01 # unit time per frame in seconds, equal to `window_stride` in ASR configs, typically 10ms. + frame_unit_time_secs: float = ( + 0.01 # unit time per frame in seconds, equal to `window_stride` in ASR configs, typically 10ms. 
+ ) audio_type: str = "wav" # Output settings, no need to change @@ -263,7 +258,9 @@ def extract_audio_features(manifest_filepath: str, cfg: DictConfig, record_fn: C 'vad_stream': False, 'sample_rate': cfg.sample_rate, 'manifest_filepath': manifest_filepath, - 'labels': ['infer',], + 'labels': [ + 'infer', + ], 'num_workers': cfg.num_workers, 'shuffle': False, 'normalize_audio_db': cfg.normalize_audio_db, @@ -274,10 +271,11 @@ def extract_audio_features(manifest_filepath: str, cfg: DictConfig, record_fn: C with record_fn("feat_extract_loop"): for i, test_batch in enumerate(tqdm(vad_model.test_dataloader(), total=len(vad_model.test_dataloader()))): test_batch = [x.to(vad_model.device) for x in test_batch] - with autocast(): + with torch.amp.autocast(vad_model.device.type): with record_fn("feat_extract_infer"): processed_signal, processed_signal_length = vad_model.preprocessor( - input_signal=test_batch[0], length=test_batch[1], + input_signal=test_batch[0], + length=test_batch[1], ) with record_fn("feat_extract_other"): processed_signal = processed_signal.squeeze(0)[:, :processed_signal_length] @@ -317,7 +315,9 @@ def run_vad_inference(manifest_filepath: str, cfg: DictConfig, record_fn: Callab test_data_config = { 'vad_stream': True, 'manifest_filepath': manifest_filepath, - 'labels': ['infer',], + 'labels': [ + 'infer', + ], 'num_workers': cfg.num_workers, 'shuffle': False, 'window_length_in_sec': vad_cfg.vad.parameters.window_length_in_sec, @@ -438,7 +438,7 @@ def generate_vad_frame_pred( with record_fn("vad_infer_loop"): for i, test_batch in enumerate(tqdm(vad_model.test_dataloader(), total=len(vad_model.test_dataloader()))): test_batch = [x.to(vad_model.device) for x in test_batch] - with autocast(): + with torch.amp.autocast(vad_model.device.type): with record_fn("vad_infer_model"): if use_feat: log_probs = vad_model(processed_signal=test_batch[0], processed_signal_length=test_batch[1]) @@ -572,7 +572,7 @@ def run_asr_inference(manifest_filepath, cfg, record_fn) -> str: hypotheses = [] all_hypotheses = [] t0 = time.time() - with autocast(): + with torch.amp.autocast(asr_model.device.type): with torch.no_grad(): with record_fn("asr_infer_loop"): for test_batch in tqdm(dataloader, desc="Transcribing"): @@ -585,7 +585,11 @@ def run_asr_inference(manifest_filepath, cfg, record_fn) -> str: with record_fn("asr_infer_other"): logits, logits_len = outputs[0], outputs[1] - current_hypotheses, all_hyp = decode_function(logits, logits_len, return_hypotheses=False,) + current_hypotheses, all_hyp = decode_function( + logits, + logits_len, + return_hypotheses=False, + ) if isinstance(current_hypotheses, tuple) and len(current_hypotheses) == 2: current_hypotheses = current_hypotheses[0] # handle RNNT output diff --git a/examples/asr/experimental/sclite/speech_to_text_sclite.py b/examples/asr/experimental/sclite/speech_to_text_sclite.py index 80a47585e000..ffbf629b3ed3 100644 --- a/examples/asr/experimental/sclite/speech_to_text_sclite.py +++ b/examples/asr/experimental/sclite/speech_to_text_sclite.py @@ -42,15 +42,6 @@ from nemo.collections.asr.parts.utils.manifest_utils import read_manifest from nemo.utils import logging -try: - from torch.cuda.amp import autocast -except ImportError: - from contextlib import contextmanager - - @contextmanager - def autocast(enabled=None): - yield - def score_with_sctk(sctk_dir, ref_fname, hyp_fname, out_dir, glm=""): sclite_path = os.path.join(sctk_dir, "bin", "sclite") @@ -91,7 +82,11 @@ def get_utt_info(manifest_path): def main(): parser = ArgumentParser() 
parser.add_argument( - "--asr_model", type=str, default="QuartzNet15x5Base-En", required=False, help="Pass: 'QuartzNet15x5Base-En'", + "--asr_model", + type=str, + default="QuartzNet15x5Base-En", + required=False, + help="Pass: 'QuartzNet15x5Base-En'", ) parser.add_argument("--dataset", type=str, required=True, help="path to evaluation data") parser.add_argument("--batch_size", type=int, default=4) @@ -123,7 +118,7 @@ def main(): references = [data['text'] for data in manifest_data] audio_filepaths = [data['audio_filepath'] for data in manifest_data] - with autocast(): + with torch.amp.autocast(asr_model.device.type): hypotheses = asr_model.transcribe(audio_filepaths, batch_size=args.batch_size) # if transcriptions form a tuple (from RNNT), extract just "best" hypothesis diff --git a/examples/asr/quantization/speech_to_text_calibrate.py b/examples/asr/quantization/speech_to_text_calibrate.py index 264806c7b1ba..f5ec6e76fa27 100644 --- a/examples/asr/quantization/speech_to_text_calibrate.py +++ b/examples/asr/quantization/speech_to_text_calibrate.py @@ -35,23 +35,17 @@ "https://github.com/NVIDIA/TensorRT/tree/master/tools/pytorch-quantization." ) -try: - from torch.cuda.amp import autocast -except ImportError: - from contextlib import contextmanager - - @contextmanager - def autocast(enabled=None): - yield - - can_gpu = torch.cuda.is_available() def main(): parser = ArgumentParser() parser.add_argument( - "--asr_model", type=str, default="QuartzNet15x5Base-En", required=True, help="Pass: 'QuartzNet15x5Base-En'", + "--asr_model", + type=str, + default="QuartzNet15x5Base-En", + required=True, + help="Pass: 'QuartzNet15x5Base-En'", ) parser.add_argument("--dataset", type=str, required=True, help="path to evaluation data") parser.add_argument("--batch_size", type=int, default=256) @@ -118,11 +112,8 @@ def main(): for i, test_batch in enumerate(asr_model.test_dataloader()): if can_gpu: test_batch = [x.cuda() for x in test_batch] - if args.amp: - with autocast(): + with torch.amp.autocast(asr_model.device.type, enabled=args.amp): _ = asr_model(input_signal=test_batch[0], input_signal_length=test_batch[1]) - else: - _ = asr_model(input_signal=test_batch[0], input_signal_length=test_batch[1]) if i >= args.num_calib_batch: break diff --git a/examples/asr/quantization/speech_to_text_quant_infer.py b/examples/asr/quantization/speech_to_text_quant_infer.py index 029623cb90f0..b428db1ed83d 100644 --- a/examples/asr/quantization/speech_to_text_quant_infer.py +++ b/examples/asr/quantization/speech_to_text_quant_infer.py @@ -38,23 +38,17 @@ ) -try: - from torch.cuda.amp import autocast -except ImportError: - from contextlib import contextmanager - - @contextmanager - def autocast(enabled=None): - yield - - can_gpu = torch.cuda.is_available() def main(): parser = ArgumentParser() parser.add_argument( - "--asr_model", type=str, default="QuartzNet15x5Base-En", required=True, help="Pass: 'QuartzNet15x5Base-En'", + "--asr_model", + type=str, + default="QuartzNet15x5Base-En", + required=True, + help="Pass: 'QuartzNet15x5Base-En'", ) parser.add_argument("--dataset", type=str, required=True, help="path to evaluation data") parser.add_argument("--wer_target", type=float, default=None, help="used by test") @@ -199,7 +193,7 @@ def evaluate(asr_model, labels_map, wer): for test_batch in asr_model.test_dataloader(): if can_gpu: test_batch = [x.cuda() for x in test_batch] - with autocast(): + with torch.amp.autocast(asr_model.device.type): log_probs, encoded_len, greedy_predictions = asr_model( 
input_signal=test_batch[0], input_signal_length=test_batch[1] ) diff --git a/examples/asr/quantization/speech_to_text_quant_infer_trt.py b/examples/asr/quantization/speech_to_text_quant_infer_trt.py index e9916d6e7449..3fb982002c0c 100644 --- a/examples/asr/quantization/speech_to_text_quant_infer_trt.py +++ b/examples/asr/quantization/speech_to_text_quant_infer_trt.py @@ -43,20 +43,15 @@ can_gpu = torch.cuda.is_available() -try: - from torch.cuda.amp import autocast -except ImportError: - from contextlib import contextmanager - - @contextmanager - def autocast(enabled=None): - yield - def main(): parser = ArgumentParser() parser.add_argument( - "--asr_model", type=str, default="QuartzNet15x5Base-En", required=True, help="Pass: 'QuartzNet15x5Base-En'", + "--asr_model", + type=str, + default="QuartzNet15x5Base-En", + required=True, + help="Pass: 'QuartzNet15x5Base-En'", ) parser.add_argument( "--asr_onnx", @@ -145,9 +140,11 @@ def build_trt_engine(asr_model, onnx_path, qat): network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) if qat: network_flags |= 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_PRECISION) - with builder.create_network(flags=network_flags) as network, trt.OnnxParser( - network, TRT_LOGGER - ) as parser, builder.create_builder_config() as builder_config: + with ( + builder.create_network(flags=network_flags) as network, + trt.OnnxParser(network, TRT_LOGGER) as parser, + builder.create_builder_config() as builder_config, + ): parser.parse_from_file(onnx_path) builder_config.max_workspace_size = workspace_size * (1024 * 1024) if qat: diff --git a/examples/asr/speech_translation/translate_speech.py b/examples/asr/speech_translation/translate_speech.py index 203852b52ee9..42394001255f 100644 --- a/examples/asr/speech_translation/translate_speech.py +++ b/examples/asr/speech_translation/translate_speech.py @@ -162,16 +162,6 @@ def main(cfg: TranslationConfig) -> Union[TranslationConfig, List[str]]: # prepare audio filepaths and decide wether it's partial audio filepaths, partial_audio = prepare_audio_data(cfg) - # setup AMP (optional) - if cfg.amp and torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and hasattr(torch.cuda.amp, 'autocast'): - logging.info("AMP enabled!\n") - autocast = torch.cuda.amp.autocast - else: - - @contextlib.contextmanager - def autocast(): - yield - # Compute output filename cfg = compute_output_filename(cfg, model_name) @@ -184,10 +174,12 @@ def autocast(): return cfg # translate audio - with autocast(): + with torch.amp.autocast(asr_model.device.type, enabled=cfg.amp): with torch.no_grad(): translations = asr_model.translate( - paths2audio_files=filepaths, batch_size=cfg.batch_size, return_hypotheses=return_hypotheses, + paths2audio_files=filepaths, + batch_size=cfg.batch_size, + return_hypotheses=return_hypotheses, ) logging.info(f"Finished translating {len(filepaths)} files !") diff --git a/examples/asr/transcribe_speech.py b/examples/asr/transcribe_speech.py index a8df6bc5a911..f3a1c3fc8162 100644 --- a/examples/asr/transcribe_speech.py +++ b/examples/asr/transcribe_speech.py @@ -358,16 +358,6 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis filepaths = sorted_manifest_path if sorted_manifest_path is not None else filepaths - # setup AMP (optional) - if cfg.amp and torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and hasattr(torch.cuda.amp, 'autocast'): - logging.info("AMP enabled!\n") - autocast = torch.cuda.amp.autocast - else: - - @contextlib.contextmanager - def 
autocast(dtype=None, enabled=True): - yield - # Compute output filename cfg = compute_output_filename(cfg, model_name) @@ -393,7 +383,7 @@ def autocast(dtype=None, enabled=True): ) total_duration += item["duration"] - with autocast(dtype=amp_dtype, enabled=cfg.amp): + with torch.amp.autocast('cuda' if torch.cuda.is_available() else 'cpu', dtype=amp_dtype, enabled=cfg.amp): with torch.no_grad(): if cfg.calculate_rtfx: start_time = time.time() diff --git a/examples/slu/speech_intent_slot/eval_utils/inference.py b/examples/slu/speech_intent_slot/eval_utils/inference.py index d83d48b688fc..9bd76c76822d 100644 --- a/examples/slu/speech_intent_slot/eval_utils/inference.py +++ b/examples/slu/speech_intent_slot/eval_utils/inference.py @@ -14,7 +14,6 @@ # limitations under the License. -import contextlib import glob import json import os @@ -60,7 +59,12 @@ class InferenceConfig: sequence_generator: SequenceGeneratorConfig = SequenceGeneratorConfig(type="greedy") -def slurp_inference(model, path2manifest: str, batch_size: int = 4, num_workers: int = 0,) -> List[str]: +def slurp_inference( + model, + path2manifest: str, + batch_size: int = 4, + num_workers: int = 0, +) -> List[str]: if num_workers is None: num_workers = min(batch_size, os.cpu_count() - 1) @@ -178,16 +182,6 @@ def run_inference(cfg: InferenceConfig) -> InferenceConfig: logging.info(f"\nStart inference with {len(filepaths)} files...\n") - # setup AMP (optional) - if cfg.amp and torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and hasattr(torch.cuda.amp, 'autocast'): - logging.info("AMP enabled!\n") - autocast = torch.cuda.amp.autocast - else: - - @contextlib.contextmanager - def autocast(): - yield - # Compute output filename if cfg.output_filename is None: # create default output filename @@ -206,7 +200,7 @@ def autocast(): return cfg # transcribe audio - with autocast(): + with torch.amp.autocast(model.device.type, enabled=cfg.amp): with torch.no_grad(): predictions = slurp_inference( model=model, diff --git a/nemo/collections/asr/models/clustering_diarizer.py b/nemo/collections/asr/models/clustering_diarizer.py index 98e56a7be48d..ddcc269bedcc 100644 --- a/nemo/collections/asr/models/clustering_diarizer.py +++ b/nemo/collections/asr/models/clustering_diarizer.py @@ -49,15 +49,6 @@ from nemo.core.classes import Model from nemo.utils import logging, model_utils -try: - from torch.cuda.amp import autocast -except ImportError: - from contextlib import contextmanager - - @contextmanager - def autocast(enabled=None): - yield - __all__ = ['ClusteringDiarizer'] @@ -223,7 +214,7 @@ def _run_vad(self, manifest_file): tqdm(self._vad_model.test_dataloader(), desc='vad', leave=True, disable=not self.verbose) ): test_batch = [x.to(self._vad_model.device) for x in test_batch] - with autocast(): + with torch.amp.autocast(self._vad_model.device.type): log_probs = self._vad_model(input_signal=test_batch[0], input_signal_length=test_batch[1]) probs = torch.softmax(log_probs, dim=-1) pred = probs[:, 1] @@ -359,7 +350,7 @@ def _extract_embeddings(self, manifest_file: str, scale_idx: int, num_scales: in ): test_batch = [x.to(self._speaker_model.device) for x in test_batch] audio_signal, audio_signal_len, labels, slices = test_batch - with autocast(): + with torch.amp.autocast(self._speaker_model.device.type): _, embs = self._speaker_model.forward(input_signal=audio_signal, input_signal_length=audio_signal_len) emb_shape = embs.shape[-1] embs = embs.view(-1, emb_shape) diff --git a/nemo/collections/asr/parts/k2/graph_transducer.py 
b/nemo/collections/asr/parts/k2/graph_transducer.py index 5de8064224a1..bcd49bcbd7a9 100644 --- a/nemo/collections/asr/parts/k2/graph_transducer.py +++ b/nemo/collections/asr/parts/k2/graph_transducer.py @@ -25,7 +25,7 @@ def force_float32_context() -> ContextManager: """Get context manager to force float32 precision in autocast mode.""" if torch.is_autocast_enabled(): - return torch.cuda.amp.autocast(dtype=torch.float32) + return torch.amp.autocast('cuda', dtype=torch.float32) return nullcontext() @@ -159,7 +159,10 @@ def get_graphs_batched( # composed version text_fsas = [ - self.get_unit_schema(units_tensor=targets[i, : target_lengths[i].item()], vocab_size=vocab_size,) + self.get_unit_schema( + units_tensor=targets[i, : target_lengths[i].item()], + vocab_size=vocab_size, + ) for i in range(batch_size) ] temporal_fsas = [ @@ -192,7 +195,8 @@ def get_logits_indices(self, target_fsas_vec: k2.Fsa, logits_shape: torch.Size) scores_to_batch_i = torch.repeat_interleave( torch.arange(batch_size, device=device, dtype=torch.int64), torch.tensor( - [target_fsas_vec.arcs.index(0, i)[0].values().shape[0] for i in range(batch_size)], device=device, + [target_fsas_vec.arcs.index(0, i)[0].values().shape[0] for i in range(batch_size)], + device=device, ), ) indices = ( @@ -442,7 +446,11 @@ def get_grid(self, units_tensor: torch.Tensor, num_frames: int, vocab_size: int) return rnnt_graph def forward( - self, acts: torch.Tensor, labels: torch.Tensor, act_lens: torch.Tensor, label_lens: torch.Tensor, + self, + acts: torch.Tensor, + labels: torch.Tensor, + act_lens: torch.Tensor, + label_lens: torch.Tensor, ) -> torch.Tensor: """ Compute forward method for RNN-T. diff --git a/nemo/collections/asr/parts/preprocessing/features.py b/nemo/collections/asr/parts/preprocessing/features.py index d70737b5135b..5138a3148e91 100644 --- a/nemo/collections/asr/parts/preprocessing/features.py +++ b/nemo/collections/asr/parts/preprocessing/features.py @@ -433,7 +433,7 @@ def forward(self, x, seq_len, linear_spec=False): x = torch.cat((x[:, 0].unsqueeze(1), x[:, 1:] - self.preemph * x[:, :-1]), dim=1) # disable autocast to get full range of stft values - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast(x.device.type, enabled=False): x = self.stft(x) # torch stft returns complex tensor (of shape [B,N,T]); so convert to magnitude @@ -627,7 +627,7 @@ def _apply_log(self, features: torch.Tensor) -> torch.Tensor: def _extract_spectrograms(self, signals: torch.Tensor) -> torch.Tensor: # Complex FFT needs to be done in single precision - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast('cuda', enabled=False): features = self._mel_spec_extractor(waveform=signals) return features diff --git a/nemo/collections/asr/parts/submodules/jasper.py b/nemo/collections/asr/parts/submodules/jasper.py index 78f81ee555bc..ec0def1b3ebb 100644 --- a/nemo/collections/asr/parts/submodules/jasper.py +++ b/nemo/collections/asr/parts/submodules/jasper.py @@ -473,7 +473,7 @@ def forward_for_export(self, x, lengths): self.set_max_len(max_len) dtype = x.dtype # Computes in float32 to avoid instabilities during training with AMP. 
- with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast(x.device.type, enabled=False): # Create sample mask - 1 represents value, 0 represents pad mask = self.make_pad_mask(lengths, max_audio_length=max_len, device=x.device) mask = ~mask # 0 represents value, 1 represents pad diff --git a/nemo/collections/asr/parts/utils/asr_confidence_benchmarking_utils.py b/nemo/collections/asr/parts/utils/asr_confidence_benchmarking_utils.py index 96f90bee363c..262c98401f95 100644 --- a/nemo/collections/asr/parts/utils/asr_confidence_benchmarking_utils.py +++ b/nemo/collections/asr/parts/utils/asr_confidence_benchmarking_utils.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import contextlib import copy import os from pathlib import Path @@ -68,7 +67,7 @@ def run_confidence_benchmark( batch_size: int = 8, num_workers: int = 4, plot_dir: Optional[Union[str, Path]] = None, - autocast: Optional = None, + use_amp: Optional[bool] = False, ): """Run benchmark and plot histograms and curves, if plot_dir is provided. @@ -81,15 +80,8 @@ def run_confidence_benchmark( plot_dir = Path(plot_dir) is_rnnt = isinstance(model, EncDecRNNTModel) - # setup autocast if necessary - if autocast is None: - - @contextlib.contextmanager - def autocast(): - yield - # transcribe audio - with autocast(): + with torch.amp.autocast(model.device.type, enabled=use_amp): with torch.no_grad(): transcriptions = model.transcribe( audio=filepaths, batch_size=batch_size, return_hypotheses=True, num_workers=num_workers diff --git a/nemo/collections/asr/parts/utils/decoder_timestamps_utils.py b/nemo/collections/asr/parts/utils/decoder_timestamps_utils.py index c39ff7da58d9..59e050c5f656 100644 --- a/nemo/collections/asr/parts/utils/decoder_timestamps_utils.py +++ b/nemo/collections/asr/parts/utils/decoder_timestamps_utils.py @@ -449,7 +449,7 @@ def run_ASR_QuartzNet_CTC(self, asr_model: Type[EncDecCTCModel]) -> Tuple[Dict, log_prediction=asr_model._cfg.get("log_prediction", False), ) - with torch.cuda.amp.autocast(): + with torch.amp.autocast(asr_model.device.type): transcript_hyps_list = asr_model.transcribe( self.audio_file_list, batch_size=self.asr_batch_size, return_hypotheses=True ) # type: List[nemo_asr.parts.Hypothesis] @@ -577,7 +577,7 @@ def run_ASR_CitriNet_CTC(self, asr_model: Type[EncDecCTCModelBPE]) -> Tuple[Dict log_prediction=asr_model._cfg.get("log_prediction", False), ) - with torch.cuda.amp.autocast(): + with torch.amp.autocast(asr_model.device.type): transcript_hyps_list = asr_model.transcribe( self.audio_file_list, batch_size=self.asr_batch_size, return_hypotheses=True ) # type: List[nemo_asr.parts.Hypothesis] @@ -671,7 +671,7 @@ def run_ASR_BPE_CTC(self, asr_model: Type[EncDecCTCModelBPE]) -> Tuple[Dict, Dic onset_delay, mid_delay, tokens_per_chunk = self.set_buffered_infer_params(asr_model) onset_delay_in_sec = round(onset_delay * self.model_stride_in_secs, 2) - with torch.cuda.amp.autocast(): + with torch.amp.autocast(asr_model.device.type): logging.info(f"Running ASR model {self.ASR_model_name}") for idx, audio_file_path in enumerate(self.audio_file_list): diff --git a/nemo/collections/asr/parts/utils/transcribe_utils.py b/nemo/collections/asr/parts/utils/transcribe_utils.py index cb5d21bf760a..c1e712c44aeb 100644 --- a/nemo/collections/asr/parts/utils/transcribe_utils.py +++ b/nemo/collections/asr/parts/utils/transcribe_utils.py @@ -18,7 +18,7 @@ from dataclasses import dataclass from pathlib import Path from tempfile import 
NamedTemporaryFile -from typing import List, Tuple, Union +from typing import List, Optional, Tuple, Union import torch from omegaconf import DictConfig @@ -42,6 +42,7 @@ def get_buffered_pred_feat_rnnt( batch_size: int, manifest: str = None, filepaths: List[list] = None, + accelerator: Optional[str] = 'cpu', ) -> List[rnnt_utils.Hypothesis]: """ Moved from examples/asr/asr_chunked_inference/rnnt/speech_to_text_buffered_infer_rnnt.py @@ -67,7 +68,7 @@ def get_buffered_pred_feat_rnnt( refs.append(row['text']) with torch.inference_mode(): - with torch.cuda.amp.autocast(): + with torch.amp.autocast('cpu' if accelerator == 'cpu' else 'cuda'): batch = [] asr.sample_offset = 0 for idx in tqdm(range(len(filepaths)), desc='Sample:', total=len(filepaths)): diff --git a/nemo/collections/asr/parts/utils/vad_utils.py b/nemo/collections/asr/parts/utils/vad_utils.py index 138b2e36b7fa..29b4f7b33898 100644 --- a/nemo/collections/asr/parts/utils/vad_utils.py +++ b/nemo/collections/asr/parts/utils/vad_utils.py @@ -40,16 +40,6 @@ from nemo.collections.common.parts.preprocessing.manifest import get_full_path from nemo.utils import logging -try: - from torch.cuda.amp import autocast -except ImportError: - from contextlib import contextmanager - - @contextmanager - def autocast(enabled=None): - yield - - """ This file contains all the utility functions required for voice activity detection. """ @@ -1127,7 +1117,7 @@ def generate_vad_frame_pred( status = get_vad_stream_status(data) for i, test_batch in enumerate(tqdm(vad_model.test_dataloader(), total=len(vad_model.test_dataloader()))): test_batch = [x.to(vad_model.device) for x in test_batch] - with autocast(): + with torch.amp.autocast(vad_model.device.type): if use_feat: log_probs = vad_model(processed_signal=test_batch[0], processed_signal_length=test_batch[1]) else: diff --git a/nemo/collections/audio/modules/masking.py b/nemo/collections/audio/modules/masking.py index cfb575eea879..3f0380dccb5d 100644 --- a/nemo/collections/audio/modules/masking.py +++ b/nemo/collections/audio/modules/masking.py @@ -668,6 +668,7 @@ def forward(self, input: torch.Tensor, activity: torch.Tensor) -> torch.Tensor: """ B, num_inputs, F, T = input.shape num_outputs = activity.size(1) + device = input.device.type if activity.size(0) != B: raise ValueError(f'Batch dimension mismatch: activity {activity.shape} vs input {input.shape}') @@ -678,7 +679,7 @@ def forward(self, input: torch.Tensor, activity: torch.Tensor) -> torch.Tensor: if num_outputs == 1: raise ValueError(f'Expecting multiple outputs, got {num_outputs}') - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast(device, enabled=False): input = input.to(dtype=self.dtype) assert input.is_complex(), f'Expecting complex input, got {input.dtype}' @@ -1039,8 +1040,9 @@ def forward( shape (B, C, F, T). 
""" io_dtype = input.dtype + device = input.device.type - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast(device, enabled=False): output = input.to(dtype=self.dtype) if not output.is_complex(): diff --git a/nemo/collections/audio/modules/transforms.py b/nemo/collections/audio/modules/transforms.py index 6839ae0f7598..cfa0c2c8ebb7 100644 --- a/nemo/collections/audio/modules/transforms.py +++ b/nemo/collections/audio/modules/transforms.py @@ -143,7 +143,7 @@ def forward( input = input.view(B, -1, T) # STFT output (B, C, F, N) - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast(input.device.type, enabled=False): output = self.stft(input.float()) if self.magnitude_power != 1: @@ -265,7 +265,7 @@ def forward( input = input.view(B, -1, T) # STFT output (B, C, F, N) - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast(input.device.type, enabled=False): output = self.stft(input.float()) if self.magnitude_power != 1: @@ -414,7 +414,7 @@ def forward(self, input: torch.Tensor, input_length: Optional[torch.Tensor] = No input = input.view(B, -1, F, N) # iSTFT output (B, C, T) - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast(input.device.type, enabled=False): output = input.cfloat() if self.scale != 1: @@ -533,7 +533,7 @@ def forward(self, input: torch.Tensor, input_length: Optional[torch.Tensor] = No input = input.view(B, -1, F, N) # iSTFT output (B, C, T) - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast(input.device.type, enabled=False): output = input.cfloat() if self.scale != 1: diff --git a/nemo/collections/audio/parts/submodules/multichannel.py b/nemo/collections/audio/parts/submodules/multichannel.py index aff0f28cfc3a..0fa4f8bf238b 100644 --- a/nemo/collections/audio/parts/submodules/multichannel.py +++ b/nemo/collections/audio/parts/submodules/multichannel.py @@ -597,7 +597,7 @@ def forward(self, input: torch.Tensor, mask_s: torch.Tensor, mask_n: torch.Tenso """ iodtype = input.dtype - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast(self.device.type, enabled=False): # Convert to double input = input.cdouble() mask_s = mask_s.double() diff --git a/nemo/collections/tts/data/dataset.py b/nemo/collections/tts/data/dataset.py index 348862ceddec..83d2b969ea91 100644 --- a/nemo/collections/tts/data/dataset.py +++ b/nemo/collections/tts/data/dataset.py @@ -504,7 +504,7 @@ def add_reference_audio(self, **kwargs): raise NotImplementedError(f"Reference audio type \"{reference_audio_type}\" is not supported.") def get_spec(self, audio): - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast(audio.device.type, enabled=False): spec = self.stft(audio) if spec.dtype in [torch.cfloat, torch.cdouble]: spec = torch.view_as_real(spec) @@ -512,7 +512,7 @@ def get_spec(self, audio): return spec def get_log_mel(self, audio): - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast(audio.device.type, enabled=False): spec = self.get_spec(audio) mel = torch.matmul(self.fb.to(spec.dtype), spec) log_mel = torch.log(torch.clamp(mel, min=torch.finfo(mel.dtype).tiny)) @@ -652,7 +652,7 @@ def __getitem__(self, index): sr=self.sample_rate, fill_na=0.0, ) - for (i, voiced_name, voiced_filepath) in non_exist_voiced_index: + for i, voiced_name, voiced_filepath in non_exist_voiced_index: my_var.__setitem__(voiced_name, torch.from_numpy(voiced_tuple[i]).float()) torch.save(my_var.get(voiced_name), voiced_filepath) @@ -859,9 +859,9 @@ def general_collate_fn(self, 
batch): durations_list.append(general_padding(durations, len(durations), max_durations_len)) if AlignPriorMatrix in self.sup_data_types_set: - align_prior_matrices[ - i, : align_prior_matrix.shape[0], : align_prior_matrix.shape[1] - ] = align_prior_matrix + align_prior_matrices[i, : align_prior_matrix.shape[0], : align_prior_matrix.shape[1]] = ( + align_prior_matrix + ) if Pitch in self.sup_data_types_set: pitches.append(general_padding(pitch, pitch_length.item(), max_pitches_len)) @@ -901,9 +901,9 @@ def general_collate_fn(self, batch): "p_voiced": torch.stack(p_voiceds) if P_voiced in self.sup_data_types_set else None, "audio_shifted": torch.stack(audios_shifted) if audio_shifted is not None else None, "reference_audio": torch.stack(reference_audios) if ReferenceAudio in self.sup_data_types_set else None, - "reference_audio_lens": torch.stack(reference_audio_lengths) - if ReferenceAudio in self.sup_data_types_set - else None, + "reference_audio_lens": ( + torch.stack(reference_audio_lengths) if ReferenceAudio in self.sup_data_types_set else None + ), } return data_dict @@ -1162,7 +1162,8 @@ def __len__(self): class PairedRealFakeSpectrogramsDataset(Dataset): def __init__( - self, manifest_filepath: Union[str, Path], + self, + manifest_filepath: Union[str, Path], ): manifest_filepath = Path(manifest_filepath) with Path(manifest_filepath).open() as f: @@ -1215,7 +1216,6 @@ def __init__( speaker_stats_pitch_fp: Optional[Union[str, Path]] = None, speaker_conditioning_type: Optional[str] = "per_sample", # per_sample, mean, interpolate, ): - """Dataset used for training FastPitchModel_SSL model. Requires supplementary data created using scripts/ssl_tts/make_supdata.py Args: @@ -1226,7 +1226,7 @@ def __init__( "speaker" : "duration": (Optional) sample_rate (int): The sample rate of the audio. Or the sample rate that we will resample all files to. - ssl_content_emb_type (str): One of ["probs", "embedding", "log_probs", "embedding_and_probs"]. + ssl_content_emb_type (str): One of ["probs", "embedding", "log_probs", "embedding_and_probs"]. Indicated which output to use as content embedding. max_duration (Optional[float]): Max duration of audio clips in seconds. All samples exceeding this will be pruned prior to training. Note: Requires "duration" to be set in the manifest file. It does not load @@ -1239,18 +1239,18 @@ def __init__( trim (bool): Whether to apply `librosa.effects.trim` to trim leading and trailing silence from an audio signal. Defaults to False. pitch_conditioning (bool): Whether to load pitch contour or not - pitch_mean (Optional[float]): If using global normalization, normalize using these statistics. + pitch_mean (Optional[float]): If using global normalization, normalize using these statistics. Also used if speaker stats are not available for the given speaker - pitch_std (Optional[float]): If using global normalization, normalize using these statistics. + pitch_std (Optional[float]): If using global normalization, normalize using these statistics. Also used if speaker stats are not available for the given speaker pitch_normalization (str): Can be one of ['speaker_wise', 'global', 'none']. Indicates the kind of pitch normalization. - sup_data_dir (Optional[Union[str, Path]]): Data directory containing pre-computed embeddings/statistics. If set as - speaker_stats_pitch_fp (Optional[Union[str, Path]]): Path to the json containing speaker pitch stats. - If set as None, tries to lookup for a default filename (speaker_pitch_stats.json) in sup_data_dir. 
+ sup_data_dir (Optional[Union[str, Path]]): Data directory containing pre-computed embeddings/statistics. If set as + speaker_stats_pitch_fp (Optional[Union[str, Path]]): Path to the json containing speaker pitch stats. + If set as None, tries to lookup for a default filename (speaker_pitch_stats.json) in sup_data_dir. Needed if we use pitch_normalization is "speaker_wise" speaker_conditioning_type (Optional[str]): Can be one of ["per_sample", "mean", "interpolate"]. Defaults to "per_sample" per_sample: Speaker embedding computed from the same utterance - mean: Speaker embedding for all utterances of a given speaker is the same and equal to the mean speaker embedding. + mean: Speaker embedding for all utterances of a given speaker is the same and equal to the mean speaker embedding. interpolate: Interpolate b/w per_sample and mean speaker embedding. """ assert ssl_content_emb_type in ["probs", "embedding", "log_probs", "embedding_and_probs"] @@ -1328,7 +1328,10 @@ def __init__( def _get_wav_from_filepath(self, audio_filepath): features = AudioSegment.segment_from_file( - audio_filepath, target_sr=self.sample_rate, n_segments=-1, trim=self.trim, + audio_filepath, + target_sr=self.sample_rate, + n_segments=-1, + trim=self.trim, ) audio_samples = features.samples @@ -1531,7 +1534,7 @@ class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler): Maintain similar input lengths in a batch. Length groups are specified by boundaries. Ex) boundaries = [b1, b2, b3] -> any batch is included either {x | b1 < length(x) <=b2} or {x | b2 < length(x) <= b3}. - + It removes samples which are not included in the boundaries. Ex) boundaries = [b1, b2, b3] -> any x s.t. length(x) <= b1 or length(x) > b3 are discarded. """ diff --git a/nemo/collections/tts/models/aligner.py b/nemo/collections/tts/models/aligner.py index 9aeb5fbe23ca..72d023e9ee10 100644 --- a/nemo/collections/tts/models/aligner.py +++ b/nemo/collections/tts/models/aligner.py @@ -117,12 +117,14 @@ def _setup_tokenizer(self, cfg): if "phoneme_dict" in cfg.text_tokenizer.g2p: g2p_kwargs["phoneme_dict"] = self.register_artifact( - 'text_tokenizer.g2p.phoneme_dict', cfg.text_tokenizer.g2p.phoneme_dict, + 'text_tokenizer.g2p.phoneme_dict', + cfg.text_tokenizer.g2p.phoneme_dict, ) if "heteronyms" in cfg.text_tokenizer.g2p: g2p_kwargs["heteronyms"] = self.register_artifact( - 'text_tokenizer.g2p.heteronyms', cfg.text_tokenizer.g2p.heteronyms, + 'text_tokenizer.g2p.heteronyms', + cfg.text_tokenizer.g2p.heteronyms, ) text_tokenizer_kwargs["g2p"] = instantiate(cfg.text_tokenizer.g2p, **g2p_kwargs) @@ -130,7 +132,7 @@ def _setup_tokenizer(self, cfg): self.tokenizer = instantiate(cfg.text_tokenizer, **text_tokenizer_kwargs) def forward(self, *, spec, spec_len, text, text_len, attn_prior=None): - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast(self.device.type, enabled=False): attn_soft, attn_logprob = self.alignment_encoder( queries=spec, keys=self.embed(text).transpose(1, 2), @@ -236,7 +238,9 @@ def _loader(self, cfg): text_tokenizer=self.tokenizer, ) return torch.utils.data.DataLoader( # noqa - dataset=dataset, collate_fn=dataset.collate_fn, **cfg.dataloader_params, + dataset=dataset, + collate_fn=dataset.collate_fn, + **cfg.dataloader_params, ) def setup_training_data(self, cfg): diff --git a/nemo/collections/tts/modules/common.py b/nemo/collections/tts/modules/common.py index 5f7d6153a7d1..cc7019439662 100644 --- a/nemo/collections/tts/modules/common.py +++ b/nemo/collections/tts/modules/common.py @@ -19,8 
+19,6 @@ import numpy as np import torch from torch import Tensor, nn -from torch.cuda import amp -from torch.cuda.amp import autocast as autocast from torch.nn import functional as F from nemo.collections.tts.modules.submodules import ConvNorm, LinearNorm, MaskedInstanceNorm1d @@ -96,7 +94,7 @@ def lstm_nocast(self, context: Tensor, lens: Tensor) -> Tensor: dtype = context.dtype # autocast guard is only needed for Torchscript to run in Triton # (https://github.com/pytorch/pytorch/issues/89241) - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast(self.device.type, enabled=False): # Calculate sizes and prepare views to our zero buffer to pass as hx max_batch_size = context.shape[0] context = context.to(dtype=torch.float32) @@ -171,7 +169,10 @@ def forward(self, context: Tensor, lens: Tensor) -> Tensor: def get_radtts_encoder( - encoder_n_convolutions=3, encoder_embedding_dim=512, encoder_kernel_size=5, norm_fn=MaskedInstanceNorm1d, + encoder_n_convolutions=3, + encoder_embedding_dim=512, + encoder_kernel_size=5, + norm_fn=MaskedInstanceNorm1d, ): return ConvLSTMLinear( in_dim=encoder_embedding_dim, @@ -203,7 +204,7 @@ def __init__(self, c): self.upper_diag = nn.Parameter(torch.diag(upper)) self.upper = nn.Parameter(torch.triu(upper, 1)) - @amp.autocast(False) + @torch.amp.autocast(device_type='cuda', enabled=False) def forward(self, z, inverse=False): U = torch.triu(self.upper, 1) + torch.diag(self.upper_diag) L = torch.tril(self.lower, -1) + torch.diag(self.lower_diag) @@ -280,7 +281,7 @@ def __init__( out_channels = -1 self.use_partial_padding = use_partial_padding for i in range(n_layers): - dilation = 2 ** i if with_dilation else 1 + dilation = 2**i if with_dilation else 1 padding = int((kernel_size * dilation - dilation) / 2) out_channels = min(max_channels, in_channels * 2) self.layers.append( @@ -354,7 +355,7 @@ def __init__( self.end = end for i in range(n_layers): - dilation = 2 ** i + dilation = 2**i padding = int((kernel_size * dilation - dilation) / 2) in_layer = ConvNorm( n_channels, @@ -469,7 +470,7 @@ def forward(self, z, context, inverse=False): z_reshaped = z.permute(0, 2, 1).reshape(b_s * t_s, -1) affine_params = self.param_predictor(context) q_tilde = affine_params.permute(0, 2, 1).reshape(b_s * t_s, c_s, -1) - with amp.autocast(enabled=False): + with torch.amp.autocast(self.device.type, enabled=False): if self.use_quadratic: w = q_tilde[:, :, : self.n_bins // 2] v = q_tilde[:, :, self.n_bins // 2 :] @@ -554,7 +555,7 @@ def forward(self, z, context, inverse=False, seq_lens=None): z_1_reshaped = z_1.permute(0, 2, 1).reshape(b_s * t_s, -1) q_tilde = affine_params.permute(0, 2, 1).reshape(b_s * t_s, n_half, self.n_bins) - with autocast(enabled=False): + with torch.amp.autocast(self.device.type, enabled=False): if self.use_quadratic: w = q_tilde[:, :, : self.n_bins // 2] v = q_tilde[:, :, self.n_bins // 2 :] diff --git a/nemo/utils/cast_utils.py b/nemo/utils/cast_utils.py index a7960be4cc4d..72d6c5c496d9 100644 --- a/nemo/utils/cast_utils.py +++ b/nemo/utils/cast_utils.py @@ -24,7 +24,7 @@ def avoid_bfloat16_autocast_context(): """ if torch.is_autocast_enabled() and torch.get_autocast_gpu_dtype() == torch.bfloat16: - return torch.cuda.amp.autocast(dtype=torch.float32) + return torch.amp.autocast('cuda', dtype=torch.float32) else: return nullcontext() @@ -37,12 +37,12 @@ def avoid_float16_autocast_context(): if torch.is_autocast_enabled() and torch.get_autocast_gpu_dtype() == torch.float16: if torch.jit.is_scripting() or torch.jit.is_tracing(): - return 
torch.cuda.amp.autocast(dtype=torch.float32) + return torch.amp.autocast('cuda', dtype=torch.float32) if torch.cuda.is_bf16_supported(): - return torch.cuda.amp.autocast(dtype=torch.bfloat16) + return torch.amp.autocast('cuda', dtype=torch.bfloat16) else: - return torch.cuda.amp.autocast(dtype=torch.float32) + return torch.amp.autocast('cuda', dtype=torch.float32) else: return nullcontext() @@ -71,7 +71,7 @@ def __init__(self, mod): def forward(self, x): if torch.is_autocast_enabled() and x.dtype != torch.float32: - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast(x.device.type, enabled=False): ret = self.mod.forward(x.to(torch.float32)).to(x.dtype) else: ret = self.mod.forward(x) @@ -86,7 +86,7 @@ def __init__(self, mod): def forward(self, *args): if torch.is_autocast_enabled(): from_dtype = args[0].dtype - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast(self.device.type, enabled=False): ret = self.mod.forward(*cast_all(args, from_dtype=from_dtype, to_dtype=torch.float32)) return cast_all(ret, from_dtype=torch.float32, to_dtype=from_dtype) else: diff --git a/nemo/utils/export_utils.py b/nemo/utils/export_utils.py index 534598097bf4..8bc01f652188 100644 --- a/nemo/utils/export_utils.py +++ b/nemo/utils/export_utils.py @@ -149,7 +149,7 @@ def verify_torchscript(model, output, input_examples, check_tolerance=0.01): for input_example in input_examples: input_list, input_dict = parse_input_example(input_example) # We disable autocast here to make sure exported TS will run under Triton or other C++ env - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast('cuda', enabled=False): output_example = model.forward(*input_list, **input_dict) ts_model = torch.jit.load(output) all_good = all_good and run_ts_and_compare( diff --git a/scripts/asr_context_biasing/eval_greedy_decoding_with_context_biasing.py b/scripts/asr_context_biasing/eval_greedy_decoding_with_context_biasing.py index d9275fd26fe9..4c62a62e31b6 100644 --- a/scripts/asr_context_biasing/eval_greedy_decoding_with_context_biasing.py +++ b/scripts/asr_context_biasing/eval_greedy_decoding_with_context_biasing.py @@ -254,7 +254,9 @@ def decoding_step( probs_batch[prob_index].unsqueeze(0), device=packed_batch.device, dtype=packed_batch.dtype ) best_hyp_batch, beams_batch = asr_model.decoding.rnnt_decoder_predictions_tensor( - packed_batch, probs_lens, return_hypotheses=True, + packed_batch, + probs_lens, + return_hypotheses=True, ) beams_batch = [[x] for x in best_hyp_batch] @@ -356,17 +358,8 @@ def main(cfg: EvalContextBiasingConfig): durations.append(data['duration']) audio_file_paths.append(str(audio_file.absolute())) - if cfg.use_amp: - if torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and hasattr(torch.cuda.amp, 'autocast'): - logging.info("AMP is enabled!\n") - autocast = torch.cuda.amp.autocast - else: - autocast = contextlib.nullcontext - else: - autocast = contextlib.nullcontext - # manual calculation of encoder_embeddings - with autocast(): + with torch.amp.autocast(asr_model.device.type, enabled=cfg.use_amp): with torch.no_grad(): asr_model.eval() asr_model.encoder.freeze() diff --git a/scripts/asr_language_modeling/neural_rescorer/eval_neural_rescorer.py b/scripts/asr_language_modeling/neural_rescorer/eval_neural_rescorer.py index d0b4b1a61204..f2fbebd1bf4a 100644 --- a/scripts/asr_language_modeling/neural_rescorer/eval_neural_rescorer.py +++ b/scripts/asr_language_modeling/neural_rescorer/eval_neural_rescorer.py @@ -224,23 +224,13 @@ def main(): dataset = 
BeamScoresDataset(args.beams_file, model_tokenizer, args.eval_manifest, args.beam_size, max_seq_length) data_loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=args.batch_size) - if args.use_amp: - if torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and hasattr(torch.cuda.amp, 'autocast'): - logging.info("AMP is enabled!\n") - autocast = torch.cuda.amp.autocast - else: - - @contextlib.contextmanager - def autocast(): - yield - if "attention_mask" in inspect.getfullargspec(model.forward).args: support_att_mask = True else: support_att_mask = False logging.info(f"Rescoring with beam_size: {args.beam_size}") logging.info("Calculating the scores...") - with autocast(): + with torch.amp.autocast(model.device.type, enabled=args.use_amp): with torch.no_grad(): am_scores, lm_scores, dists, ref_lens, lens_in_chars = [], [], [], [], [] for batch in tqdm.tqdm(data_loader): diff --git a/scripts/asr_language_modeling/ngram_lm/eval_beamsearch_ngram_ctc.py b/scripts/asr_language_modeling/ngram_lm/eval_beamsearch_ngram_ctc.py index 2af8283c7b82..3bb4fa4f4846 100644 --- a/scripts/asr_language_modeling/ngram_lm/eval_beamsearch_ngram_ctc.py +++ b/scripts/asr_language_modeling/ngram_lm/eval_beamsearch_ngram_ctc.py @@ -194,7 +194,9 @@ def beam_search_eval( ) _, beams_batch = decoding.ctc_decoder_predictions_tensor( - packed_batch, decoder_lengths=probs_lens, return_hypotheses=True, + packed_batch, + decoder_lengths=probs_lens, + return_hypotheses=True, ) for beams_idx, beams in enumerate(beams_batch): @@ -312,22 +314,7 @@ def main(cfg: EvalBeamSearchNGramConfig): ) else: - @contextlib.contextmanager - def default_autocast(): - yield - - if cfg.use_amp: - if torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and hasattr(torch.cuda.amp, 'autocast'): - logging.info("AMP is enabled!\n") - autocast = torch.cuda.amp.autocast - - else: - autocast = default_autocast - else: - - autocast = default_autocast - - with autocast(): + with torch.amp.autocast(asr_model.device.type, enabled=cfg.use_amp): with torch.no_grad(): if isinstance(asr_model, EncDecHybridRNNTCTCModel): asr_model.cur_decoder = 'ctc' diff --git a/scripts/asr_language_modeling/ngram_lm/eval_beamsearch_ngram_transducer.py b/scripts/asr_language_modeling/ngram_lm/eval_beamsearch_ngram_transducer.py index 8548b839024f..c61a402c0942 100644 --- a/scripts/asr_language_modeling/ngram_lm/eval_beamsearch_ngram_transducer.py +++ b/scripts/asr_language_modeling/ngram_lm/eval_beamsearch_ngram_transducer.py @@ -177,7 +177,9 @@ def decoding_step( probs_batch[prob_index].unsqueeze(0), device=packed_batch.device, dtype=packed_batch.dtype ) best_hyp_batch, beams_batch = model.decoding.rnnt_decoder_predictions_tensor( - packed_batch, probs_lens, return_hypotheses=True, + packed_batch, + probs_lens, + return_hypotheses=True, ) if cfg.decoding_strategy == "greedy_batch": beams_batch = [[x] for x in best_hyp_batch] @@ -296,23 +298,8 @@ def main(cfg: EvalBeamSearchNGramConfig): ) else: - @contextlib.contextmanager - def default_autocast(): - yield - - if cfg.use_amp: - if torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and hasattr(torch.cuda.amp, 'autocast'): - logging.info("AMP is enabled!\n") - autocast = torch.cuda.amp.autocast - - else: - autocast = default_autocast - else: - - autocast = default_autocast - # manual calculation of encoder_embeddings - with autocast(): + with torch.amp.autocast(asr_model.device.type, enabled=cfg.use_amp): with torch.no_grad(): asr_model.eval() asr_model.encoder.freeze() diff --git 
a/scripts/asr_language_modeling/ngram_lm/eval_wfst_decoding_ctc.py b/scripts/asr_language_modeling/ngram_lm/eval_wfst_decoding_ctc.py index a1db7cec4f23..63ab24b0921e 100644 --- a/scripts/asr_language_modeling/ngram_lm/eval_wfst_decoding_ctc.py +++ b/scripts/asr_language_modeling/ngram_lm/eval_wfst_decoding_ctc.py @@ -300,22 +300,7 @@ def main(cfg: EvalWFSTNGramConfig): ) else: - @contextlib.contextmanager - def default_autocast(): - yield - - if cfg.use_amp: - if torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and hasattr(torch.cuda.amp, 'autocast'): - logging.info("AMP is enabled!\n") - autocast = torch.cuda.amp.autocast - - else: - autocast = default_autocast - else: - - autocast = default_autocast - - with autocast(): + with torch.amp.autocast(asr_model.device.type, enabled=cfg.use_amp): with torch.no_grad(): if isinstance(asr_model, EncDecHybridRNNTCTCModel): asr_model.cur_decoder = 'ctc' diff --git a/scripts/export.py b/scripts/export.py index dbe5b2b7fe2b..acfd3e3e3450 100644 --- a/scripts/export.py +++ b/scripts/export.py @@ -48,7 +48,8 @@ def get_args(argv): parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter, description=f"Export NeMo models to ONNX/Torchscript", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + description=f"Export NeMo models to ONNX/Torchscript", ) parser.add_argument("source", help="Source .nemo file") parser.add_argument("out", help="Location to write result to") @@ -154,11 +155,8 @@ def nemo_export(argv): kv[k] = v model.set_export_config(kv) - autocast = nullcontext - if args.autocast: - autocast = torch.cuda.amp.autocast try: - with autocast(), torch.no_grad(), torch.inference_mode(): + with torch.amp.autocast(args.device, enabled=args.autocast), torch.no_grad(), torch.inference_mode(): model.to(device=args.device).freeze() model.eval() input_example = None diff --git a/scripts/speech_recognition/confidence/benchmark_asr_confidence.py b/scripts/speech_recognition/confidence/benchmark_asr_confidence.py index 0c119b02ff7b..9c42ef6cca5b 100644 --- a/scripts/speech_recognition/confidence/benchmark_asr_confidence.py +++ b/scripts/speech_recognition/confidence/benchmark_asr_confidence.py @@ -209,12 +209,6 @@ def main(cfg: ConfidenceBenchmarkingConfig): filepaths.append(str(audio_file.absolute())) reference_texts.append(item['text']) - # setup AMP (optional) - autocast = None - if cfg.amp and torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and hasattr(torch.cuda.amp, 'autocast'): - logging.info("AMP enabled!\n") - autocast = torch.cuda.amp.autocast - # do grid-based benchmarking if grid_params is provided, otherwise a regular one work_dir = Path(cfg.output_dir) os.makedirs(work_dir, exist_ok=True) @@ -275,7 +269,7 @@ def main(cfg: ConfidenceBenchmarkingConfig): cfg.batch_size, cfg.num_workers, plot_dir, - autocast, + cfg.amp, ) for level, result in results.items(): f.write(f"{model_typename},{','.join(param_list)},{level},{','.join([str(r) for r in result])}\n") @@ -303,7 +297,7 @@ def main(cfg: ConfidenceBenchmarkingConfig): filepaths, reference_texts, plot_dir, - autocast, + cfg.amp, ) for level, result in results.items(): f.write(f"{model_typename},{','.join(param_list)},{level},{','.join([str(r) for r in result])}\n")
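For reference, a minimal sketch (not part of the patch) of the device-agnostic autocast pattern the changes above converge on. Here `asr_model`, `batch`, `use_amp`, and `amp_dtype` are stand-ins for the model, dataloader batch, and config flags used in the individual scripts; the point is only that torch.amp.autocast takes the device type explicitly and that enabled=False makes it a no-op, which is what replaces the old try/except and contextlib fallbacks around torch.cuda.amp.autocast.

import torch

def infer_with_optional_amp(asr_model, batch, use_amp: bool = False, amp_dtype=None):
    # Resolve 'cuda' or 'cpu' from the model itself instead of assuming CUDA is present.
    device_type = asr_model.device.type
    with torch.inference_mode():
        # enabled=use_amp lets the same call site run with AMP on or off;
        # dtype=None lets autocast pick the device default (float16 on CUDA, bfloat16 on CPU).
        with torch.amp.autocast(device_type, enabled=use_amp, dtype=amp_dtype):
            return asr_model(input_signal=batch[0], input_signal_length=batch[1])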