diff --git a/nemo/collections/asr/parts/submodules/ctc_beam_decoding.py b/nemo/collections/asr/parts/submodules/ctc_beam_decoding.py index da5f94d2cbcc..0beab5f54cb1 100644 --- a/nemo/collections/asr/parts/submodules/ctc_beam_decoding.py +++ b/nemo/collections/asr/parts/submodules/ctc_beam_decoding.py @@ -797,19 +797,14 @@ def _riva_decoding(self, x: torch.Tensor, out_len: torch.Tensor) -> List['WfstNb A list of WfstNbestHypothesis objects, one for each sequence in the batch. """ if self.riva_decoder is None: - try: - from riva.asrlib.decoder.python_decoder import BatchedMappedDecoderCuda - except (ImportError, ModuleNotFoundError) as e: - logging.warning( - "Problem loading Riva decoder. Please (re-) install using `pip install riva-asrlib-decoder`." - ) - raise e - lm_fst = self._prepare_decoding_lm_wfst() if self.open_vocabulary_decoding and self._tokenword_disambig_id == -1: + # trying to extract tokenword_disambig_id from the lm_fst if isinstance(lm_fst, str): - import kaldifst + # use importer instead of direct import to possibly get an installation message + from nemo.collections.asr.parts.utils.wfst_utils import kaldifst_importer + kaldifst = kaldifst_importer() lm_fst = kaldifst.StdVectorFst.read(self.wfst_lm_path) tokenword_disambig_id = lm_fst.output_symbols.find("#1") if tokenword_disambig_id == -1: diff --git a/nemo/collections/asr/parts/submodules/wfst_decoder.py b/nemo/collections/asr/parts/submodules/wfst_decoder.py index 0435bdc962ed..59afac7b6c4f 100644 --- a/nemo/collections/asr/parts/submodules/wfst_decoder.py +++ b/nemo/collections/asr/parts/submodules/wfst_decoder.py @@ -24,7 +24,22 @@ from jiwer import wer as word_error_rate from omegaconf import DictConfig -from nemo.collections.asr.parts.utils.wfst_utils import TW_BREAK +from nemo.collections.asr.parts.utils.wfst_utils import kaldifst_importer, TW_BREAK + + +RIVA_DECODER_INSTALLATION_MESSAGE = ( + "riva decoder is not installed or is installed incorrectly.\n" + "please run `bash 
scripts/installers/install_riva_decoder.sh` or `pip install riva-asrlib-decoder` to install." +) + + +def riva_decoder_importer(): + """Import helper function that returns Riva asrlib decoder package or raises ImportError exception.""" + try: + import riva.asrlib.decoder.python_decoder as riva_decoder + except (ImportError, ModuleNotFoundError): + raise ImportError(RIVA_DECODER_INSTALLATION_MESSAGE) + return riva_decoder def _riva_config_to_dict(conf: Any) -> Dict[str, Any]: @@ -76,9 +91,9 @@ class RivaDecoderConfig(DictConfig): def __init__(self): try: - from riva.asrlib.decoder.python_decoder import BatchedMappedDecoderCudaConfig + riva_decoder = riva_decoder_importer() - config = BatchedMappedDecoderCudaConfig() + config = riva_decoder.BatchedMappedDecoderCudaConfig() config.online_opts.lattice_postprocessor_opts.acoustic_scale = 10.0 config.n_input_per_chunk = 50 config.online_opts.decoder_opts.default_beam = 20.0 @@ -90,7 +105,7 @@ def __init__(self): config.online_opts.lattice_postprocessor_opts.word_ins_penalty = 0.0 content = _riva_config_to_dict(config) - except (ImportError, ModuleNotFoundError): + except ImportError: content = {} super().__init__(content) @@ -456,6 +471,12 @@ def __init__( def _set_decoder_config(self, config: Optional['RivaDecoderConfig'] = None): if config is None or len(config) == 0: config = RivaDecoderConfig() + if not hasattr(config, "online_opts"): + # most likely empty config + # call importer to raise the exception + installation message + riva_decoder_importer() + # just in case + raise RuntimeError("Unexpected config error. 
Please debug manually.") config.online_opts.decoder_opts.lattice_beam = self._beam_size config.online_opts.lattice_postprocessor_opts.lm_scale = ( self._lm_weight * config.online_opts.lattice_postprocessor_opts.acoustic_scale @@ -464,8 +485,10 @@ def _set_decoder_config(self, config: Optional['RivaDecoderConfig'] = None): self._config = config def _init_decoder(self): - import kaldifst - from riva.asrlib.decoder.python_decoder import BatchedMappedDecoderCuda, BatchedMappedDecoderCudaConfig + + # use importers instead of direct import to possibly get an installation message + kaldifst = kaldifst_importer() + riva_decoder = riva_decoder_importer() from nemo.collections.asr.parts.utils.wfst_utils import load_word_lattice @@ -510,9 +533,11 @@ def _init_decoder(self): self._token2id[k] = v with tempfile.NamedTemporaryFile(mode='w+t') as words_tmp: tmp_fst.output_symbols.write_text(words_tmp.name) - config = BatchedMappedDecoderCudaConfig() + config = riva_decoder.BatchedMappedDecoderCudaConfig() _fill_inner_riva_config_(config, self._config) - self._decoder = BatchedMappedDecoderCuda(config, lm_fst, words_tmp.name, num_tokens_with_blank) + self._decoder = riva_decoder.BatchedMappedDecoderCuda( + config, lm_fst, words_tmp.name, num_tokens_with_blank + ) if tmp_fst_file: tmp_fst_file.close() @@ -759,6 +784,7 @@ def _release_gpu_memory(self): try: del self._decoder except Exception: + # apparently self._decoder was previously deleted, do nothing pass gc.collect() diff --git a/nemo/collections/asr/parts/utils/wfst_utils.py b/nemo/collections/asr/parts/utils/wfst_utils.py index c6bf604697a1..84bcc522c760 100644 --- a/nemo/collections/asr/parts/utils/wfst_utils.py +++ b/nemo/collections/asr/parts/utils/wfst_utils.py @@ -28,11 +28,15 @@ TW_BREAK = "‡" -# almost every function/method uses kaldifst try: import kaldifst -except (ImportError, ModuleNotFoundError): - raise ImportError("kaldifst is not installed.\n" "please run `pip install kaldifst` to install.") + + # check that 
kaldifst package is not empty + # Note: pytorch_lightning.utilities.imports.package_available may not help here + kaldifst.StdVectorFst() + _KALDIFST_AVAILABLE = True +except (ImportError, ModuleNotFoundError, AttributeError): + _KALDIFST_AVAILABLE = False try: @@ -51,16 +55,55 @@ _KALDILM_AVAILABLE = False +KALDIFST_INSTALLATION_MESSAGE = ( + "kaldifst is not installed or is installed incorrectly.\n" + "please run `pip install kaldifst` or `bash scripts/installers/install_riva_decoder.sh` to install." +) + + +GRAPHVIZ_INSTALLATION_MESSAGE = ( + "graphviz is not installed.\n" + "please run `bash scripts/installers/install_graphviz.sh` to install." +) + + +KALDILM_INSTALLATION_MESSAGE = ( + "kaldilm is not installed.\n" + "please run `pip install kaldilm` or `bash scripts/installers/install_riva_decoder.sh` to install." +) + + +def _kaldifst_maybe_raise(): + if _KALDIFST_AVAILABLE is False: + raise ImportError(KALDIFST_INSTALLATION_MESSAGE) + + +def kaldifst_importer(): + """Import helper function that returns kaldifst package or raises ImportError exception.""" + _kaldifst_maybe_raise() + return kaldifst + + def _graphviz_maybe_raise(): if _GRAPHVIZ_AVAILABLE is False: - raise ImportError( - "graphviz is not installed.\n" "please run `bash scripts/installers/install_graphviz.sh` to install." 
- ) + raise ImportError(GRAPHVIZ_INSTALLATION_MESSAGE) + + +def graphviz_importer(): + """Import helper function that returns graphviz package or raises ImportError exception.""" + _graphviz_maybe_raise() + return graphviz def _kaldilm_maybe_raise(): if _KALDILM_AVAILABLE is False: - raise ImportError("kaldilm is not installed.\n" "please run `pip install kaldilm` to install.") + raise ImportError(KALDILM_INSTALLATION_MESSAGE) + + +def kaldilm_importer(): + """Import helper function that returns kaldilm package or raises ImportError exception.""" + _kaldilm_maybe_raise() + return kaldilm @dataclass @@ -172,6 +215,7 @@ def arpa2fst(lm_path: str, attach_symbol_table: bool = True) -> 'kaldifst.StdVec Returns: Kaldi-type grammar WFST. """ + _kaldifst_maybe_raise() _kaldilm_maybe_raise() with tempfile.TemporaryDirectory() as tempdirname: @@ -230,6 +274,8 @@ def add_tokenwords_( Returns: The id of the tokenword disambiguation token. """ + _kaldifst_maybe_raise() + unigram_state = 0 # check if 0 is the unigram state (has no outgoing epsilon arcs) assert kaldifst.ArcIterator(g_fst, unigram_state).value.ilabel not in (0, g_fst.output_symbols.find("#0")) @@ -458,6 +504,8 @@ def make_lexicon_fst_no_silence( Returns: Kaldi-type lexicon WFST. """ + _kaldifst_maybe_raise() + backoff_disambig = "#0" tokenword_disambig = "#1" tokenword_mode = tokenword_disambig in lexicon.word2id @@ -605,6 +653,8 @@ def build_topo( Returns: Kaldi-type topology WFST. 
""" + _kaldifst_maybe_raise() + if name == "default": fst = build_default_topo(token2id, with_self_loops) elif name == "compact": @@ -625,6 +675,8 @@ def build_topo( def build_default_topo(token2id: Dict[str, int], with_self_loops: bool = True) -> 'kaldifst.StdVectorFst': """Build the default (correct) CTC topology.""" + _kaldifst_maybe_raise() + disambig_pattern = re.compile(r"^#\d+$") blank_id = token2id[""] fst = kaldifst.StdVectorFst() @@ -711,6 +763,8 @@ def build_default_topo(token2id: Dict[str, int], with_self_loops: bool = True) - def build_compact_topo(token2id: Dict[str, int], with_self_loops: bool = True) -> 'kaldifst.StdVectorFst': """Build the Compact CTC topology.""" + _kaldifst_maybe_raise() + disambig_pattern = re.compile(r"^#\d+$") blank_id = token2id[""] fst = kaldifst.StdVectorFst() @@ -776,6 +830,8 @@ def build_compact_topo(token2id: Dict[str, int], with_self_loops: bool = True) - def build_minimal_topo(token2id: Dict[str, int]) -> 'kaldifst.StdVectorFst': """Build the Minimal CTC topology.""" + _kaldifst_maybe_raise() + disambig_pattern = re.compile(r"^#\d+$") blank_id = token2id[""] fst = kaldifst.StdVectorFst() @@ -858,6 +914,8 @@ def mkgraph_ctc_ov( Returns: A pair of kaldi- or k2-type decoding WFST and its id of the tokenword disambiguation token. """ + _kaldifst_maybe_raise() + logging.info("Compiling G.fst ...") G = arpa2fst(lm_path) if open_vocabulary: @@ -1045,6 +1103,8 @@ def __init__( symbol_table: Optional[Dict[int, str]] = None, auxiliary_tables: Optional[Dict[str, Any]] = None, ): + _kaldifst_maybe_raise() + if not isinstance(lattice, kaldifst.Lattice): raise ValueError(f"Wrong lattice type: `{type(lattice)}`") super().__init__(lattice) @@ -1143,6 +1203,8 @@ def edit_distance(self, reference_sequence: List[int]) -> int: Returns: Number of edits. """ + _kaldifst_maybe_raise() + if not self.properties.InputEpsilonFree: logging.warning(f"Lattice contains input epsilons. 
Edit distance calculations may not be accurate.") if not all(reference_sequence): @@ -1180,6 +1242,7 @@ def draw( Returns: graphviz.Digraph or IPython.display.HTML """ + _kaldifst_maybe_raise() _graphviz_maybe_raise() isym, osym = None, None @@ -1286,6 +1349,8 @@ def levenshtein_graph_kaldi( Returns: Kaldi-type levenshtein WFST. """ + _kaldifst_maybe_raise() + if fst.properties(KaldiFstMask.Acceptor.value, True) != KaldiFstMask.Acceptor.value: logging.warning( "Levenshtein graph construction is not safe for WFSTs with different input and output symbols." @@ -1349,6 +1414,8 @@ def load_word_lattice( Returns: Dictionary with lattice names and corresponding lattices in KaldiWordLattice format. """ + _kaldifst_maybe_raise() + lattice_dict = {} lattice = None max_state = 0 diff --git a/nemo/core/utils/k2_utils.py b/nemo/core/utils/k2_utils.py index 3dff6a35d3e3..3e7c2a6f5a70 100644 --- a/nemo/core/utils/k2_utils.py +++ b/nemo/core/utils/k2_utils.py @@ -16,7 +16,7 @@ K2_INSTALLATION_MESSAGE = ( "Could not import `k2`.\n" "Please install k2 in one of the following ways:\n" - "1) (recommended) Run `bash scripts/speech_recognition/k2/setup.sh`\n" + "1) (recommended) Run `bash scripts/installers/install_k2.sh`\n" "2) Use any approach from https://k2-fsa.github.io/k2/installation/index.html " "if your your cuda and pytorch versions are supported.\n" "It is advised to always install k2 using setup.sh only, " diff --git a/requirements/requirements_asr.txt b/requirements/requirements_asr.txt index 1b63e1dd0fa1..7745f5326047 100644 --- a/requirements/requirements_asr.txt +++ b/requirements/requirements_asr.txt @@ -4,9 +4,7 @@ einops g2p_en jiwer kaldi-python-io -kaldifst kaldiio -kaldilm lhotse>=1.24.2 librosa>=0.10.0 marshmallow @@ -16,7 +14,6 @@ pyannote.metrics pydub pyloudnorm resampy -riva-asrlib-decoder ruamel.yaml scipy>=0.14 soundfile diff --git a/scripts/asr_language_modeling/ngram_lm/eval_wfst_decoding_ctc.py 
b/scripts/asr_language_modeling/ngram_lm/eval_wfst_decoding_ctc.py index b8aea6435159..a1db7cec4f23 100644 --- a/scripts/asr_language_modeling/ngram_lm/eval_wfst_decoding_ctc.py +++ b/scripts/asr_language_modeling/ngram_lm/eval_wfst_decoding_ctc.py @@ -188,8 +188,8 @@ def beam_search_eval( packed_batch = torch.zeros(len(probs_batch), max(probs_lens), probs_batch[0].shape[-1], device='cpu') for prob_index in range(len(probs_batch)): - packed_batch[prob_index, : probs_lens[prob_index], :] = torch.tensor( - probs_batch[prob_index], device=packed_batch.device, dtype=packed_batch.dtype + packed_batch[prob_index, : probs_lens[prob_index], :] = probs_batch[prob_index].to( + device=packed_batch.device, dtype=packed_batch.dtype ) _, beams_batch = decoding.ctc_decoder_predictions_tensor( @@ -228,7 +228,7 @@ def beam_search_eval( score = candidate.score if preds_output_file: - out_file.write('{}\t{}\n'.format(pred_text, score)) + out_file.write(f'{pred_text}\t{score}\n') wer_dist_best += wer_dist_min cer_dist_best += cer_dist_min sample_idx += len(probs_batch) @@ -337,7 +337,7 @@ def default_autocast(): chars_count = 0 for batch_idx, probs in enumerate(all_probs): preds = np.argmax(probs, axis=1) - preds_tensor = torch.tensor(preds, device='cpu').unsqueeze(0) + preds_tensor = preds.to(device='cpu').unsqueeze(0) preds_lens = torch.tensor([preds_tensor.shape[1]], device='cpu') if isinstance(asr_model, EncDecHybridRNNTCTCModel): pred_text = asr_model.ctc_decoding.ctc_decoder_predictions_tensor(preds_tensor, preds_lens)[0][0] @@ -375,8 +375,6 @@ def default_autocast(): f"Could not find both the ARPA model file `{cfg.arpa_model_file}` " f"and the decoding WFST file `{cfg.decoding_wfst_file}`." 
) - lm_path = cfg.arpa_model_file - wfst_path = cfg.decoding_wfst_file if cfg.beam_width is None or cfg.lm_weight is None: raise ValueError("beam_width and lm_weight are needed to perform WFST decoding.") diff --git a/scripts/installers/install_riva_decoder.sh b/scripts/installers/install_riva_decoder.sh new file mode 100755 index 000000000000..4e6e99b570ab --- /dev/null +++ b/scripts/installers/install_riva_decoder.sh @@ -0,0 +1,17 @@ +#!/bin/bash + +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +pip install kaldifst kaldilm riva-asrlib-decoder