Skip to content

Commit

Permalink
isolate decoder components installation and fix suggestions
Browse files Browse the repository at this point in the history
Signed-off-by: Aleksandr Laptev <[email protected]>
  • Loading branch information
GNroy committed Aug 2, 2024
1 parent f788c91 commit 9826eb6
Show file tree
Hide file tree
Showing 7 changed files with 134 additions and 34 deletions.
13 changes: 4 additions & 9 deletions nemo/collections/asr/parts/submodules/ctc_beam_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -797,19 +797,14 @@ def _riva_decoding(self, x: torch.Tensor, out_len: torch.Tensor) -> List['WfstNb
A list of WfstNbestHypothesis objects, one for each sequence in the batch.
"""
if self.riva_decoder is None:
try:
from riva.asrlib.decoder.python_decoder import BatchedMappedDecoderCuda
except (ImportError, ModuleNotFoundError) as e:
logging.warning(
"Problem loading Riva decoder. Please (re-) install using `pip install riva-asrlib-decoder`."
)
raise e

lm_fst = self._prepare_decoding_lm_wfst()
if self.open_vocabulary_decoding and self._tokenword_disambig_id == -1:
# trying to extract tokenword_disambig_id from the lm_fst
if isinstance(lm_fst, str):
import kaldifst
# use importer instead of direct import to possibly get an installation message
from nemo.collections.asr.parts.utils.wfst_utils import kaldifst_importer

kaldifst = kaldifst_importer()
lm_fst = kaldifst.StdVectorFst.read(self.wfst_lm_path)
tokenword_disambig_id = lm_fst.output_symbols.find("#1")
if tokenword_disambig_id == -1:
Expand Down
42 changes: 34 additions & 8 deletions nemo/collections/asr/parts/submodules/wfst_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,22 @@
from jiwer import wer as word_error_rate
from omegaconf import DictConfig

from nemo.collections.asr.parts.utils.wfst_utils import TW_BREAK
from nemo.collections.asr.parts.utils.wfst_utils import kaldifst_importer, TW_BREAK


# Text of the ImportError raised by riva_decoder_importer() when the Riva decoder import fails.
RIVA_DECODER_INSTALLATION_MESSAGE = (
"riva decoder is not installed or is installed incorrectly.\n"
"please run `bash scripts/installers/install_riva_decoder.sh` or `pip install riva-asrlib-decoder` to install."
)


def riva_decoder_importer():
    """Return the Riva asrlib decoder module, importing it at call time.

    The import is performed inside the function rather than at module import,
    so a failure surfaces only when the decoder is actually requested.

    Returns:
        The `riva.asrlib.decoder.python_decoder` module.

    Raises:
        ImportError: If the module cannot be imported. The message carries
            installation instructions and the original error is chained as
            the cause so the underlying failure stays visible in tracebacks.
    """
    try:
        import riva.asrlib.decoder.python_decoder as riva_decoder
    except ImportError as err:
        # ModuleNotFoundError is a subclass of ImportError, so one clause covers both.
        raise ImportError(RIVA_DECODER_INSTALLATION_MESSAGE) from err
    return riva_decoder


def _riva_config_to_dict(conf: Any) -> Dict[str, Any]:
Expand Down Expand Up @@ -76,9 +91,9 @@ class RivaDecoderConfig(DictConfig):

def __init__(self):
try:
from riva.asrlib.decoder.python_decoder import BatchedMappedDecoderCudaConfig
riva_decoder = riva_decoder_importer()

config = BatchedMappedDecoderCudaConfig()
config = riva_decoder.BatchedMappedDecoderCudaConfig()
config.online_opts.lattice_postprocessor_opts.acoustic_scale = 10.0
config.n_input_per_chunk = 50
config.online_opts.decoder_opts.default_beam = 20.0
Expand All @@ -90,7 +105,7 @@ def __init__(self):
config.online_opts.lattice_postprocessor_opts.word_ins_penalty = 0.0

content = _riva_config_to_dict(config)
except (ImportError, ModuleNotFoundError):
except ImportError:
content = {}
super().__init__(content)

Expand Down Expand Up @@ -456,6 +471,12 @@ def __init__(
def _set_decoder_config(self, config: Optional['RivaDecoderConfig'] = None):
if config is None or len(config) == 0:
config = RivaDecoderConfig()
if not hasattr(config, "online_opts"):
# most likely empty config
# call importer to raise the exception + installation message
riva_decoder_importer()
# just in case
raise RuntimeError("Unexpected config error. Please debug manually.")
config.online_opts.decoder_opts.lattice_beam = self._beam_size
config.online_opts.lattice_postprocessor_opts.lm_scale = (
self._lm_weight * config.online_opts.lattice_postprocessor_opts.acoustic_scale
Expand All @@ -464,8 +485,10 @@ def _set_decoder_config(self, config: Optional['RivaDecoderConfig'] = None):
self._config = config

def _init_decoder(self):
import kaldifst
from riva.asrlib.decoder.python_decoder import BatchedMappedDecoderCuda, BatchedMappedDecoderCudaConfig

# use importers instead of direct import to possibly get an installation message
kaldifst = kaldifst_importer()
riva_decoder = riva_decoder_importer()

from nemo.collections.asr.parts.utils.wfst_utils import load_word_lattice

Expand Down Expand Up @@ -510,9 +533,11 @@ def _init_decoder(self):
self._token2id[k] = v
with tempfile.NamedTemporaryFile(mode='w+t') as words_tmp:
tmp_fst.output_symbols.write_text(words_tmp.name)
config = BatchedMappedDecoderCudaConfig()
config = riva_decoder.BatchedMappedDecoderCudaConfig()
_fill_inner_riva_config_(config, self._config)
self._decoder = BatchedMappedDecoderCuda(config, lm_fst, words_tmp.name, num_tokens_with_blank)
self._decoder = riva_decoder.BatchedMappedDecoderCuda(
config, lm_fst, words_tmp.name, num_tokens_with_blank
)
if tmp_fst_file:
tmp_fst_file.close()

Expand Down Expand Up @@ -759,6 +784,7 @@ def _release_gpu_memory(self):
try:
del self._decoder
except Exception:
# apparently self._decoder was previously deleted, do nothing
pass
gc.collect()

Expand Down
81 changes: 74 additions & 7 deletions nemo/collections/asr/parts/utils/wfst_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,15 @@
TW_BREAK = "‡"


# almost every function/method uses kaldifst
try:
import kaldifst
except (ImportError, ModuleNotFoundError):
raise ImportError("kaldifst is not installed.\n" "please run `pip install kaldifst` to install.")

# check that kaldifst package is not empty
# Note: pytorch_lightning.utilities.imports.package_available may not help here
kaldifst.StdVectorFst()
_KALDIFST_AVAILABLE = True
except (ImportError, ModuleNotFoundError, AttributeError):
_KALDIFST_AVAILABLE = False


try:
Expand All @@ -51,16 +55,55 @@
_KALDILM_AVAILABLE = False


# Text of the ImportError raised by _kaldifst_maybe_raise() when kaldifst is unavailable.
KALDIFST_INSTALLATION_MESSAGE = (
"kaldifst is not installed or is installed incorrectly.\n"
"please run `pip install kaldifst` or `bash scripts/installers/install_riva_decoder.sh` to install."
)


# Text of the ImportError raised by _graphviz_maybe_raise() when graphviz is unavailable.
GRAPHVIZ_INSTALLATION_MESSAGE = (
"graphviz is not installed.\n"
"please run `bash scripts/installers/install_graphviz.sh` to install."
)


# Text of the ImportError raised by _kaldilm_maybe_raise() when kaldilm is unavailable.
KALDILM_INSTALLATION_MESSAGE = (
"kaldilm is not installed.\n"
"please run `pip install kaldilm` or `bash scripts/installers/install_riva_decoder.sh` to install."
)


def _kaldifst_maybe_raise():
    """Raise ImportError with installation instructions when kaldifst is flagged unavailable."""
    if _KALDIFST_AVAILABLE is not False:
        # kaldifst was imported successfully at module load; nothing to report.
        return
    raise ImportError(KALDIFST_INSTALLATION_MESSAGE)


def kaldifst_importer():
"""Return the kaldifst module, or raise ImportError with installation instructions if it is unavailable."""
_kaldifst_maybe_raise()
return kaldifst


def _graphviz_maybe_raise():
if _GRAPHVIZ_AVAILABLE is False:
raise ImportError(
"graphviz is not installed.\n" "please run `bash scripts/installers/install_graphviz.sh` to install."
)
raise ImportError(GRAPHVIZ_INSTALLATION_MESSAGE)


def graphviz_importer():
"""Return the graphviz module, or raise ImportError with installation instructions if it is unavailable."""
_graphviz_maybe_raise()
return graphviz


def _kaldilm_maybe_raise():
if _KALDILM_AVAILABLE is False:
raise ImportError("kaldilm is not installed.\n" "please run `pip install kaldilm` to install.")
raise ImportError(KALDILM_INSTALLATION_MESSAGE)


def kaldilm_importer():
"""Import helper function that returns kaldilm package or raises ImportError exception."""
_kaldilm_maybe_raise()
return kaldilm


@dataclass
Expand Down Expand Up @@ -172,6 +215,7 @@ def arpa2fst(lm_path: str, attach_symbol_table: bool = True) -> 'kaldifst.StdVec
Returns:
Kaldi-type grammar WFST.
"""
_kaldifst_maybe_raise()
_kaldilm_maybe_raise()

with tempfile.TemporaryDirectory() as tempdirname:
Expand Down Expand Up @@ -230,6 +274,8 @@ def add_tokenwords_(
Returns:
The id of the tokenword disambiguation token.
"""
_kaldifst_maybe_raise()

unigram_state = 0
# check if 0 is the unigram state (has no outgoing epsilon arcs)
assert kaldifst.ArcIterator(g_fst, unigram_state).value.ilabel not in (0, g_fst.output_symbols.find("#0"))
Expand Down Expand Up @@ -458,6 +504,8 @@ def make_lexicon_fst_no_silence(
Returns:
Kaldi-type lexicon WFST.
"""
_kaldifst_maybe_raise()

backoff_disambig = "#0"
tokenword_disambig = "#1"
tokenword_mode = tokenword_disambig in lexicon.word2id
Expand Down Expand Up @@ -605,6 +653,8 @@ def build_topo(
Returns:
Kaldi-type topology WFST.
"""
_kaldifst_maybe_raise()

if name == "default":
fst = build_default_topo(token2id, with_self_loops)
elif name == "compact":
Expand All @@ -625,6 +675,8 @@ def build_topo(

def build_default_topo(token2id: Dict[str, int], with_self_loops: bool = True) -> 'kaldifst.StdVectorFst':
"""Build the default (correct) CTC topology."""
_kaldifst_maybe_raise()

disambig_pattern = re.compile(r"^#\d+$")
blank_id = token2id["<blk>"]
fst = kaldifst.StdVectorFst()
Expand Down Expand Up @@ -711,6 +763,8 @@ def build_default_topo(token2id: Dict[str, int], with_self_loops: bool = True) -

def build_compact_topo(token2id: Dict[str, int], with_self_loops: bool = True) -> 'kaldifst.StdVectorFst':
"""Build the Compact CTC topology."""
_kaldifst_maybe_raise()

disambig_pattern = re.compile(r"^#\d+$")
blank_id = token2id["<blk>"]
fst = kaldifst.StdVectorFst()
Expand Down Expand Up @@ -776,6 +830,8 @@ def build_compact_topo(token2id: Dict[str, int], with_self_loops: bool = True) -

def build_minimal_topo(token2id: Dict[str, int]) -> 'kaldifst.StdVectorFst':
"""Build the Minimal CTC topology."""
_kaldifst_maybe_raise()

disambig_pattern = re.compile(r"^#\d+$")
blank_id = token2id["<blk>"]
fst = kaldifst.StdVectorFst()
Expand Down Expand Up @@ -858,6 +914,8 @@ def mkgraph_ctc_ov(
Returns:
A pair of kaldi- or k2-type decoding WFST and its id of the tokenword disambiguation token.
"""
_kaldifst_maybe_raise()

logging.info("Compiling G.fst ...")
G = arpa2fst(lm_path)
if open_vocabulary:
Expand Down Expand Up @@ -1045,6 +1103,8 @@ def __init__(
symbol_table: Optional[Dict[int, str]] = None,
auxiliary_tables: Optional[Dict[str, Any]] = None,
):
_kaldifst_maybe_raise()

if not isinstance(lattice, kaldifst.Lattice):
raise ValueError(f"Wrong lattice type: `{type(lattice)}`")
super().__init__(lattice)
Expand Down Expand Up @@ -1143,6 +1203,8 @@ def edit_distance(self, reference_sequence: List[int]) -> int:
Returns:
Number of edits.
"""
_kaldifst_maybe_raise()

if not self.properties.InputEpsilonFree:
logging.warning(f"Lattice contains input epsilons. Edit distance calculations may not be accurate.")
if not all(reference_sequence):
Expand Down Expand Up @@ -1180,6 +1242,7 @@ def draw(
Returns:
graphviz.Digraph or IPython.display.HTML
"""
_kaldifst_maybe_raise()
_graphviz_maybe_raise()

isym, osym = None, None
Expand Down Expand Up @@ -1286,6 +1349,8 @@ def levenshtein_graph_kaldi(
Returns:
Kaldi-type levenshtein WFST.
"""
_kaldifst_maybe_raise()

if fst.properties(KaldiFstMask.Acceptor.value, True) != KaldiFstMask.Acceptor.value:
logging.warning(
"Levenshtein graph construction is not safe for WFSTs with different input and output symbols."
Expand Down Expand Up @@ -1349,6 +1414,8 @@ def load_word_lattice(
Returns:
Dictionary with lattice names and corresponding lattices in KaldiWordLattice format.
"""
_kaldifst_maybe_raise()

lattice_dict = {}
lattice = None
max_state = 0
Expand Down
2 changes: 1 addition & 1 deletion nemo/core/utils/k2_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
K2_INSTALLATION_MESSAGE = (
"Could not import `k2`.\n"
"Please install k2 in one of the following ways:\n"
"1) (recommended) Run `bash scripts/speech_recognition/k2/setup.sh`\n"
"1) (recommended) Run `bash scripts/installers/install_k2.sh`\n"
"2) Use any approach from https://k2-fsa.github.io/k2/installation/index.html "
"if your your cuda and pytorch versions are supported.\n"
"It is advised to always install k2 using setup.sh only, "
Expand Down
3 changes: 0 additions & 3 deletions requirements/requirements_asr.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@ einops
g2p_en
jiwer
kaldi-python-io
kaldifst
kaldiio
kaldilm
lhotse>=1.24.2
librosa>=0.10.0
marshmallow
Expand All @@ -16,7 +14,6 @@ pyannote.metrics
pydub
pyloudnorm
resampy
riva-asrlib-decoder
ruamel.yaml
scipy>=0.14
soundfile
Expand Down
10 changes: 4 additions & 6 deletions scripts/asr_language_modeling/ngram_lm/eval_wfst_decoding_ctc.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,8 +188,8 @@ def beam_search_eval(
packed_batch = torch.zeros(len(probs_batch), max(probs_lens), probs_batch[0].shape[-1], device='cpu')

for prob_index in range(len(probs_batch)):
packed_batch[prob_index, : probs_lens[prob_index], :] = torch.tensor(
probs_batch[prob_index], device=packed_batch.device, dtype=packed_batch.dtype
packed_batch[prob_index, : probs_lens[prob_index], :] = probs_batch[prob_index].to(
device=packed_batch.device, dtype=packed_batch.dtype
)

_, beams_batch = decoding.ctc_decoder_predictions_tensor(
Expand Down Expand Up @@ -228,7 +228,7 @@ def beam_search_eval(

score = candidate.score
if preds_output_file:
out_file.write('{}\t{}\n'.format(pred_text, score))
out_file.write(f'{pred_text}\t{score}\n')
wer_dist_best += wer_dist_min
cer_dist_best += cer_dist_min
sample_idx += len(probs_batch)
Expand Down Expand Up @@ -337,7 +337,7 @@ def default_autocast():
chars_count = 0
for batch_idx, probs in enumerate(all_probs):
preds = np.argmax(probs, axis=1)
preds_tensor = torch.tensor(preds, device='cpu').unsqueeze(0)
preds_tensor = preds.to(device='cpu').unsqueeze(0)
preds_lens = torch.tensor([preds_tensor.shape[1]], device='cpu')
if isinstance(asr_model, EncDecHybridRNNTCTCModel):
pred_text = asr_model.ctc_decoding.ctc_decoder_predictions_tensor(preds_tensor, preds_lens)[0][0]
Expand Down Expand Up @@ -375,8 +375,6 @@ def default_autocast():
f"Could not find both the ARPA model file `{cfg.arpa_model_file}` "
f"and the decoding WFST file `{cfg.decoding_wfst_file}`."
)
lm_path = cfg.arpa_model_file
wfst_path = cfg.decoding_wfst_file

if cfg.beam_width is None or cfg.lm_weight is None:
raise ValueError("beam_width and lm_weight are needed to perform WFST decoding.")
Expand Down
17 changes: 17 additions & 0 deletions scripts/installers/install_riva_decoder.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#!/bin/bash

# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Install the WFST decoding dependencies (kaldifst, kaldilm and the Riva asrlib decoder) with pip.
pip install kaldifst kaldilm riva-asrlib-decoder

0 comments on commit 9826eb6

Please sign in to comment.