Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions .github/workflows/gigaam.yml
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ jobs:

- name: Install Python dependencies
run: |
python -m pip install --upgrade pip setuptools wheel
python -m pip install --upgrade pip wheel setuptools
pip install --no-cache-dir torch==2.8.0 torchaudio==2.8.0 --index-url https://download.pytorch.org/whl/cpu
pip install --no-cache-dir -e ".[longform,tests]"

Expand Down Expand Up @@ -134,15 +134,15 @@ jobs:

- name: Check code formatting with black
run: |
black --check --diff gigaam/ tests/
black --check --diff gigaam/ tests/ triton_scripts/

- name: Check imports with isort
run: |
isort --check-only --diff gigaam/ tests/
isort --check-only --diff gigaam/ tests/ triton_scripts/

- name: Lint with flake8
run: |
flake8 --ignore=E203,W503,W504 --max-line-length=120 --statistics gigaam/ tests/
flake8 --ignore=E203,W503,W504 --max-line-length=120 --statistics gigaam/ tests/ triton_scripts/

- name: Type check with mypy
run: |
Expand Down
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,10 @@ model = AutoModel.from_pretrained("ai-sage/GigaAM-v3", revision="e2e_rnnt", trus

These and more advanced (e.g. custom audio loading, batching) examples can be found in the [Colab notebook](https://colab.research.google.com/github/salute-developers/GigaAM/blob/main/colab_example.ipynb).

### Triton Inference Server and TensorRT

All speech recognition models can also be used in a server environment in ONNX/TRT format through Triton Inference Server. For setup instructions, model conversion, and deployment details, see the [Triton Inference Server documentation](./triton_scripts/README.md).

---

## Citation
Expand Down
4 changes: 4 additions & 0 deletions README_ru.md
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,10 @@ model = AutoModel.from_pretrained("ai-sage/GigaAM-v3", revision="e2e_rnnt", trus

Эти и более продвинутые примеры (кастомная загрузка аудио, батчинг) доступны в [Colab notebook](https://colab.research.google.com/github/salute-developers/GigaAM/blob/main/colab_example.ipynb).

### Triton Inference Server и TensorRT

Все модели распознавания речи также можно использовать в серверном окружении в формате ONNX/TRT через Triton Inference Server. Инструкции по настройке, конвертации моделей и развёртыванию описаны в [документации Triton Inference Server](./triton_scripts/README.md).

---

## Citation
Expand Down
22 changes: 11 additions & 11 deletions gigaam/decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,23 +59,23 @@ def decode(
head: "CTCHead",
encoded: Tensor,
lengths: Tensor,
labels: Optional[Tensor] = None,
) -> List[Tuple[List[int], List[int]]]:
"""
CTC greedy decode: returns (token_ids, token_frames) per sample.
Token frames are time indices (0..T-1) where a token is emitted.
If labels are provided, encoded and head are not used.
"""
log_probs = head(encoder_output=encoded)
assert (
log_probs.ndim == 3
), f"Expected log_probs [B,T,C], got {tuple(log_probs.shape)}"
B, T, C = log_probs.shape
assert (
C == len(self.tokenizer) + 1
), f"Num classes {C} != len(vocab)+1 {len(self.tokenizer)+1}"

labels = log_probs.argmax(dim=-1)
if labels is None:
log_probs = head(encoder_output=encoded)
C = log_probs.shape[-1]
assert (
C == len(self.tokenizer) + 1
), f"Num classes {C} != len(vocab)+1 {len(self.tokenizer)+1}"
labels = log_probs.argmax(dim=-1)

B, T = labels.shape
device = labels.device

lengths = lengths.to(device=device).clamp(min=0, max=T)

skip_mask = labels != self.blank_id
Expand Down
23 changes: 20 additions & 3 deletions gigaam/encoder.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import math
from abc import ABC, abstractmethod
from contextlib import contextmanager
from typing import Dict, List, Optional, Tuple, Union

import torch
Expand Down Expand Up @@ -514,12 +515,13 @@ def __init__(

def input_example(
self,
batch_size: int = 1,
batch_size: int = 8,
seqlen: int = 200,
) -> Tuple[Tensor, Tensor]:
device = next(self.parameters()).device
features = torch.zeros(batch_size, self.feat_in, seqlen)
feature_lengths = torch.full([batch_size], features.shape[-1])
features = torch.randn(batch_size, self.feat_in, seqlen)
feature_lengths = torch.randint(1, seqlen + 1, (batch_size,))
feature_lengths[0] = seqlen
return features.float().to(device), feature_lengths.to(device)

def input_names(self) -> List[str]:
Expand All @@ -528,6 +530,21 @@ def input_names(self) -> List[str]:
def output_names(self) -> List[str]:
return ["encoded", "encoded_len"]

@contextmanager
def onnx_export_mode(self):
    """
    Temporarily disable the fused attention paths on every encoder layer.

    While the context is active, both ``flash_attn`` and ``torch_sdpa_attn``
    are set to ``False`` on each layer's ``self_attn`` module; the original
    flag values are restored on exit, even if the body raises. Intended to
    wrap ONNX export, where the fused kernels are presumably not
    exporter-friendly (NOTE(review): confirm against the export code path).
    """
    # Snapshot the per-layer flag pairs before touching anything.
    original_flags = [
        (layer.self_attn.flash_attn, layer.self_attn.torch_sdpa_attn)
        for layer in self.layers
    ]
    for layer in self.layers:
        layer.self_attn.flash_attn = False
        layer.self_attn.torch_sdpa_attn = False
    try:
        yield
    finally:
        # Restore flags in layer order regardless of what happened inside.
        for layer, (flash, sdpa) in zip(self.layers, original_flags):
            layer.self_attn.flash_attn = flash
            layer.self_attn.torch_sdpa_attn = sdpa

def dynamic_axes(self) -> Dict[str, Dict[int, str]]:
return {
"audio_signal": {0: "batch_size", 2: "seq_len"},
Expand Down
3 changes: 2 additions & 1 deletion gigaam/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ def to_onnx(self, dir_path: str = ".") -> None:
"""
Export onnx model encoder to the specified dir.
"""
self._to_onnx(dir_path)
with self.encoder.onnx_export_mode():
self._to_onnx(dir_path)
omegaconf.OmegaConf.save(self.cfg, f"{dir_path}/{self.cfg.model_name}.yaml")

def _to_onnx(self, dir_path: str = ".") -> None:
Expand Down
50 changes: 29 additions & 21 deletions gigaam/onnx_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,38 +18,46 @@


def infer_onnx(
wav_file: str,
wav_file: Optional[str],
model_cfg: omegaconf.DictConfig,
sessions: List[rt.InferenceSession],
sessions: List[Optional[rt.InferenceSession]],
enc_features: Optional[np.ndarray] = None,
preprocessor: Optional[FeatureExtractor] = None,
tokenizer: Optional[Tokenizer] = None,
) -> Union[str, np.ndarray]:
"""Run ONNX sessions for the model, requires preprocessor instantiating"""
"""
Run ONNX sessions for the model, requires preprocessor instantiating.
The first session (the encoder one) and wav_file can be None if enc_features is provided.
"""
model_name = model_cfg.model_name

if preprocessor is None:
assert (
enc_features is not None or sessions[0] is not None
), "At least one of encoder session or enc_features is required"

if preprocessor is None and enc_features is None:
preprocessor = hydra.utils.instantiate(model_cfg.preprocessor)
if tokenizer is None and ("ctc" in model_name or "rnnt" in model_name):
tokenizer = hydra.utils.instantiate(model_cfg.decoding).tokenizer

sgn = load_audio(wav_file)
input_signal = (
preprocessor(sgn.unsqueeze(0), torch.tensor([sgn.shape[-1]]))[0]
.detach()
.numpy()
)

enc_sess = sessions[0]
enc_inputs = {
node.name: data
for (node, data) in zip(
enc_sess.get_inputs(),
[input_signal.astype(DTYPE), [input_signal.shape[-1]]],
if enc_features is None:
sgn = load_audio(wav_file)
input_signal = (
preprocessor(sgn.unsqueeze(0), torch.tensor([sgn.shape[-1]]))[0]
.detach()
.numpy()
)
}
enc_features = enc_sess.run(
[node.name for node in enc_sess.get_outputs()], enc_inputs
)[0]
enc_sess = sessions[0]
enc_inputs = {
node.name: data
for (node, data) in zip(
enc_sess.get_inputs(),
[input_signal.astype(DTYPE), [input_signal.shape[-1]]],
)
}
enc_features = enc_sess.run(
[node.name for node in enc_sess.get_outputs()], enc_inputs
)[0]

if "emo" in model_name or "ssl" in model_name:
return enc_features
Expand Down
10 changes: 10 additions & 0 deletions triton_scripts/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
FROM nvcr.io/nvidia/tritonserver:24.10-py3

RUN pip install --no-cache-dir \
"torch>=2.6,<2.11" \
"torchaudio>=2.6,<2.11" \
sentencepiece \
omegaconf \
onnxruntime-gpu \
tqdm \
hydra-core
82 changes: 82 additions & 0 deletions triton_scripts/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# Triton Inference Server Setup

This setup supports all ASR models from the GigaAM family. Inference is implemented through a Triton ensemble: the client sends WAV files and receives transcribed texts. CTC models are converted to ONNX/TRT entirely, while RNNT models are split into encoder (ONNX/TRT) and decoder/joint components that run in Python using onnxruntime.

## Prerequisites

Navigate to the triton_scripts directory:
```bash
cd triton_scripts
```

## 0. Build Docker Image

Build the Triton Inference Server Docker image:
```bash
docker build -t gigaam-triton .
```

## 1. Convert Models to ONNX

Convert models to ONNX format. This creates `.onnx` checkpoints and configs:
```bash
python run_convert_onnx.py <model_version> # e.g., v3_ctc, v3_e2e_rnnt
```

**Note:** The script saves model configs to the preprocessing directory. For `v3` family models, preprocessing differs from earlier versions. Since Triton uses a shared preprocessing model, you can only use models with the same preprocessing simultaneously (either all `v3` models or all earlier models). The preprocessing is determined by the last model converted to ONNX.

## 2. Convert ONNX to TensorRT

Convert ONNX models to TensorRT format. This operates on whichever CTC/RNNT model version was most recently exported to ONNX in step 1. Run inside the TensorRT Docker container:
```bash
docker run --gpus all -it --rm -v $(pwd):/workspace nvcr.io/nvidia/tensorrt:24.10-py3
# inside the container:
bash run_convert_trt.sh <ctc | rnnt>
```

## 3. Start Triton Server

Run the Triton Inference Server:
```bash
docker run --gpus all --ipc=host -p 8000:8000 -p 8001:8001 -p 8002:8002 \
-v "$(pwd)/repos:/models" \
-v "$(pwd)/..:/opt/gigaam_repo" \
-e PYTHONPATH=/opt/gigaam_repo \
gigaam-triton \
tritonserver --model-repository=/models --exit-on-error=false
```

Python backend models (e.g. [`rnnt_postprocessing`](repos/rnnt_postprocessing/1/model.py)) do `import gigaam`. The `gigaam` package lives one level above this directory in the repository, and the server command above mounts the repository root into the container at `/opt/gigaam_repo` (added to `PYTHONPATH`) so the import resolves.

**Note:** For ONNX models, the default configuration uses `instance_group [{ kind: KIND_GPU }]`. To enable CPU execution, update the `instance_group` to `KIND_CPU` in the following model config files [`ctc`](repos/ctc_encoder_onnx/config.pbtxt), [`rnnt`](repos/gigaam_encoder_onnx/config.pbtxt).

## 4. Run Client

Run inference using the client:
```bash
python run_client.py <model_type> <backend> <wav_file1> [wav_file2] ...
```

Arguments:
- `model_type`: `ctc` or `rnnt`
- `backend`: `onnx` or `trt`
- `wav_file1`, `wav_file2`, ...: Paths to WAV files

Examples:
```bash
python run_client.py rnnt onnx example.wav
python run_client.py ctc trt audio1.wav audio2.wav audio3.wav
```

## Benchmark

Forward pass time in seconds on CUDA for the first 4 segments from `long_example.wav` (VAD-segmented, ~65s total audio). For torch/onnx — both single-sample and batched inference are shown.

| Backend | v3_ctc | v3_e2e_rnnt |
|:--------------|:----------------|:----------------|
| triton/trt | 0.034 ± 0.000 | 0.403 ± 0.008 |
| triton/onnx | 0.046 ± 0.001 | 0.413 ± 0.005 |
| onnx (batch) | 0.037 ± 0.004 | 0.949 ± 0.017 |
| onnx | 0.047 ± 0.001 | 1.093 ± 0.045 |
| torch (batch) | 0.036 ± 0.002 | 0.919 ± 0.002 |
| torch | 0.112 ± 0.003 | 1.008 ± 0.001 |
33 changes: 33 additions & 0 deletions triton_scripts/repos/ctc_encoder_onnx/config.pbtxt
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
name: "ctc_encoder_onnx"
platform: "onnxruntime_onnx"
max_batch_size: 0

input [
{
name: "features"
data_type: TYPE_FP32
dims: [-1, 64, -1]
},
{
name: "feature_lengths"
data_type: TYPE_INT64
dims: [-1]
}
]

output [
{
name: "token_ids"
data_type: TYPE_INT64
dims: [-1, -1]
},
{
name: "token_ids_lengths"
data_type: TYPE_INT64
dims: [-1]
}
]

instance_group [{ kind: KIND_GPU }]

parameters { key: "cudnn_conv_algo_search" value: { string_value: "1"} }
31 changes: 31 additions & 0 deletions triton_scripts/repos/ctc_encoder_trt/config.pbtxt
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
name: "ctc_encoder_trt"
platform: "tensorrt_plan"
max_batch_size: 0

input [
{
name: "features"
data_type: TYPE_FP32
dims: [-1, 64, -1]
},
{
name: "feature_lengths"
data_type: TYPE_INT64
dims: [-1]
}
]

output [
{
name: "token_ids"
data_type: TYPE_INT64
dims: [-1, -1]
},
{
name: "token_ids_lengths"
data_type: TYPE_INT64
dims: [-1]
}
]

instance_group [{ kind: KIND_GPU }]
Loading
Loading