diff --git a/nemo_skills/inference/model/__init__.py b/nemo_skills/inference/model/__init__.py index 847bd8eed5..dc5d2d3e53 100644 --- a/nemo_skills/inference/model/__init__.py +++ b/nemo_skills/inference/model/__init__.py @@ -20,6 +20,9 @@ # NIM models (speech) from .asr_nim import ASRNIMModel +# NeMo models (speech) +from .nemo_asr import NemoASRModel + # Audio utilities from .audio_utils import ( audio_file_to_base64, @@ -65,6 +68,7 @@ "sglang": SGLangModel, "tts_nim": TTSNIMModel, "asr_nim": ASRNIMModel, + "nemo_asr": NemoASRModel, } diff --git a/nemo_skills/inference/model/nemo_asr.py b/nemo_skills/inference/model/nemo_asr.py new file mode 100644 index 0000000000..b947311ead --- /dev/null +++ b/nemo_skills/inference/model/nemo_asr.py @@ -0,0 +1,556 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""NeMo ASR model client for connecting to serve_nemo_asr.py server. + +This client wraps the NeMo ASR server with the BaseModel API pattern, +allowing it to be used in the nemo-skills inference pipeline. +""" + +from __future__ import annotations + +import glob +import logging +import tarfile +import tempfile +import time +from pathlib import Path +from typing import Any, Dict, Optional +from urllib.parse import urlparse, urlunparse + +import aiofiles +import httpx + +from nemo_skills.utils import get_logger_name + +from .nim_utils import setup_ssh_tunnel, validate_unsupported_params + +LOG = logging.getLogger(get_logger_name(__file__)) + + +class NemoASRModel: + """Client wrapper for NeMo ASR server. + + This model follows the BaseModel pattern and connects to a NeMo ASR server + running serve_nemo_asr.py via HTTP. + + Parameters + ---------- + host : str, default "127.0.0.1" + Hostname or IP of the NeMo ASR server. + port : str, default "5000" + HTTP port of the server. + model : str, default "nemo-asr" + Model identifier (mostly for compatibility, server uses its loaded model). + base_url : str | None + Full base URL (overrides host:port if provided). + language_code : str, default "en" + Language code for recognition. + response_format : str, default "verbose_json" + Response format: "json" or "verbose_json". + enable_timestamps : bool, default True + Whether to request word-level timestamps. + enable_audio_chunking : bool, default True + Whether to automatically chunk long audio files. + chunk_audio_threshold_sec : int, default 30 + Audio duration threshold (in seconds) for automatic chunking. + ssh_server : str | None + SSH server for tunneling (format: [user@]host). + ssh_key_path : str | None + Path to SSH key for tunneling. + tokenizer : str | None + Accepted for API compatibility; ignored by NemoASRModel. + data_dir : str, default "" + Base directory for resolving relative audio paths. + output_dir : str | None + Accepted for API compatibility; ignored by NemoASRModel. + tarred_audio_filepaths : str | list[str] | None + Optional path(s) to tar shards for NeMo tarred ASR datasets. 
+ If configured, `prompt` can be a member filename from tarred manifest. + max_workers : int, default 64 + Maximum concurrent requests. + """ + + def __init__( + self, + host: str = "127.0.0.1", + port: str = "5000", + model: str = "nemo-asr", + *, + base_url: str | None = None, + language_code: str = "en", + response_format: str = "verbose_json", + enable_timestamps: bool = True, + enable_audio_chunking: bool = True, + chunk_audio_threshold_sec: int = 30, + ssh_server: str | None = None, + ssh_key_path: str | None = None, + tokenizer: str | None = None, + data_dir: str = "", + output_dir: str | None = None, + tarred_audio_filepaths: str | list[str] | None = None, + max_workers: int = 64, + ) -> None: + """Initialize NemoASRModel client.""" + if tokenizer is not None: + LOG.warning("NemoASRModel does not use tokenizer. Ignoring tokenizer argument.") + if output_dir is not None: + LOG.warning("NemoASRModel does not use output_dir. Ignoring output_dir argument.") + + base_path = "" + base_scheme = "http" + # Handle base_url compatibility + if base_url: + parsed = urlparse(base_url) + if parsed.scheme: + if parsed.hostname: + host = parsed.hostname + if parsed.port is not None: + port = str(parsed.port) + base_scheme = parsed.scheme + base_path = parsed.path.rstrip("/") + else: + _url = base_url.replace("http://", "").replace("https://", "") + if ":" in _url: + host, port = _url.split(":", 1) + else: + host = _url + + # Setup SSH tunnel if needed + host, port, self._tunnel = setup_ssh_tunnel(host, port, ssh_server, ssh_key_path) + + # Store attributes expected by inference code + self.model_name_or_path = model + self.server_host = host + self.server_port = port + self.data_dir = data_dir + self.output_dir = output_dir + + # ASR-specific config + self.language_code = language_code + self.response_format = response_format + self.enable_timestamps = enable_timestamps + self.enable_audio_chunking = enable_audio_chunking + self.chunk_audio_threshold_sec = chunk_audio_threshold_sec + self.tarred_audio_files = self._resolve_tarred_audio_files(tarred_audio_filepaths) + self._tar_member_index: Dict[str, Path] = {} + self._tar_local_cache: Dict[str, Path] = {} + + # Build base URL + if base_path: + self.base_url = urlunparse((base_scheme, f"{host}:{port}", base_path, "", "", "")) + else: + self.base_url = f"{base_scheme}://{host}:{port}" + + # Create HTTP client with connection pooling + limits = httpx.Limits(max_keepalive_connections=max_workers, max_connections=max_workers) + self._client = httpx.AsyncClient(limits=limits, timeout=300.0) # 5 min timeout + + LOG.info(f"Initialized NemoASRModel connecting to {self.base_url}") + + @staticmethod + def _is_datastore_path(path: str) -> bool: + """Check if path is a datastore URI (e.g. ais://bucket/object).""" + try: + from nemo.utils.data_utils import is_datastore_path + + return is_datastore_path(path) + except ImportError: + parsed = urlparse(path) + return bool(parsed.scheme) and bool(parsed.netloc) + + @staticmethod + def _download_datastore_object(store_path: str) -> Path: + """Download datastore object to local cache and return local path.""" + scheme = urlparse(store_path).scheme + try: + from nemo.utils.data_utils import DataStoreObject, open_best, resolve_cache_dir + except ImportError as e: + raise RuntimeError( + "Datastore path support requires nemo_toolkit installation. " + "Install nemo_toolkit to use ais:// or other datastore URIs." + ) from e + + # Keep AIStore behavior aligned with NeMo's native DataStoreObject path. 
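+        # For example, DataStoreObject("ais://bucket/shard_0.tar").get() downloads the
+        # object into NeMo's local cache (or reuses a cached copy) and returns its path.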
+ if scheme == "ais": + local_path = DataStoreObject(store_path).get() + if local_path is None: + raise RuntimeError(f"Failed to download datastore object: {store_path}") + return Path(local_path) + + # For non-AIS datastore URIs (e.g., s3://), use NeMo open_best + # (typically backed by Lhotse) and cache file locally. + parsed = urlparse(store_path) + if not parsed.netloc or not parsed.path: + raise ValueError(f"Invalid datastore path format: {store_path}") + rel_path = Path(parsed.netloc + parsed.path) + local_path = resolve_cache_dir() / "datastore" / scheme / rel_path + local_path.parent.mkdir(parents=True, exist_ok=True) + with open_best(store_path, mode="rb") as stream: + data = stream.read() + with open(local_path, "wb") as fout: + fout.write(data) + return local_path + + def _resolve_tarred_audio_files(self, tarred_audio_filepaths: str | list[str] | None) -> list[str]: + """Resolve configured tar paths into an explicit list of tar file references.""" + if tarred_audio_filepaths is None: + return [] + + if isinstance(tarred_audio_filepaths, str): + candidates = [x.strip() for x in tarred_audio_filepaths.split(",") if x.strip()] + else: + candidates = [str(x).strip() for x in tarred_audio_filepaths if str(x).strip()] + + tar_files: list[str] = [] + for candidate in candidates: + if self._is_datastore_path(candidate): + if not candidate.endswith(".tar"): + raise ValueError(f"Datastore tar path must end with .tar: {candidate}") + tar_files.append(candidate) + continue + + candidate_path = Path(candidate).expanduser() + if not candidate_path.is_absolute(): + candidate_path = (Path(self.data_dir) / candidate_path).expanduser() + matches = glob.glob(str(candidate_path)) + if len(matches) == 0: + if candidate_path.is_file(): + matches = [str(candidate_path)] + else: + raise FileNotFoundError(f"Tarred audio path not found: {candidate}") + for path_str in sorted(matches): + path = Path(path_str).expanduser() + if path.suffix == ".tar": + tar_files.append(str(path.absolute())) + + tar_files = list(dict.fromkeys(tar_files)) + if tar_files: + LOG.info(f"Configured {len(tar_files)} tarred audio shard(s) for NemoASRModel") + return tar_files + + def _materialize_tar_path(self, tar_ref: str) -> Path: + """Resolve tar reference (local path or datastore URI) to a local filesystem path.""" + if tar_ref in self._tar_local_cache: + return self._tar_local_cache[tar_ref] + + if self._is_datastore_path(tar_ref): + local_tar_path = self._download_datastore_object(tar_ref) + else: + local_tar_path = Path(tar_ref).expanduser() + if local_tar_path.is_file(): + pass + elif not local_tar_path.is_absolute(): + local_tar_path = (Path(self.data_dir) / local_tar_path).expanduser() + if not local_tar_path.is_file(): + raise FileNotFoundError(f"Tar file not found: {local_tar_path}") + + self._tar_local_cache[tar_ref] = local_tar_path + return local_tar_path + + def _resolve_tar_member(self, member_name: str) -> Path: + """Find tar file containing member by scanning configured tar shards.""" + if member_name in self._tar_member_index: + return self._tar_member_index[member_name] + + for tar_ref in self.tarred_audio_files: + tar_path = self._materialize_tar_path(tar_ref) + with tarfile.open(tar_path, "r") as tar: + try: + member = tar.getmember(member_name) + except KeyError: + continue + if member.isfile(): + self._tar_member_index[member_name] = tar_path + return tar_path + + raise FileNotFoundError( + f"Audio member '{member_name}' was not found in configured tarred_audio_filepaths: " + f"{[str(p) for p in 
self.tarred_audio_files]}" + ) + + @staticmethod + def _extract_audio_member(tar_path: Path, member_name: str) -> Path: + """Extract a single audio member from tar to a temporary file.""" + suffix = Path(member_name).suffix or ".wav" + with tarfile.open(tar_path, "r") as tar: + file_obj = tar.extractfile(member_name) + if file_obj is None: + raise FileNotFoundError(f"Audio member '{member_name}' not found in tar file '{tar_path}'") + with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file: + tmp_file.write(file_obj.read()) + return Path(tmp_file.name) + + def _resolve_audio_path(self, prompt: str) -> tuple[Path, Path | None]: + """Resolve prompt to a local audio file, extracting tar member when needed.""" + if ( + self._is_datastore_path(prompt) + and "::" not in prompt + and ".tar:" not in prompt + and not prompt.endswith(".tar") + ): + # Non-tar datastore object (e.g., ais://bucket/audio.wav or s3://bucket/audio.wav) + return self._download_datastore_object(prompt), None + + audio_path = Path(prompt).expanduser() + if not audio_path.is_absolute(): + audio_path = (Path(self.data_dir) / audio_path).expanduser() + if audio_path.is_file(): + return audio_path, None + + tar_part, member_name = None, None + if "::" in prompt: + tar_part, member_name = prompt.split("::", 1) + elif ".tar:" in prompt: + tar_part, member_name = prompt.rsplit(":", 1) + if tar_part is not None and member_name is not None: + tar_path = self._materialize_tar_path(tar_part) + extracted_path = self._extract_audio_member(tar_path, member_name) + return extracted_path, extracted_path + + if self.tarred_audio_files: + tar_path = self._resolve_tar_member(prompt) + extracted_path = self._extract_audio_member(tar_path, prompt) + return extracted_path, extracted_path + + raise FileNotFoundError(f"Audio file not found: {audio_path}") + + @staticmethod + def _extract_shard_id(prompt) -> int | None: + """Extract shard_id from prompt dict if present.""" + if isinstance(prompt, dict) and "shard_id" in prompt: + shard_id = prompt["shard_id"] + if isinstance(shard_id, int): + return shard_id + if isinstance(shard_id, str) and shard_id.isdigit(): + return int(shard_id) + return None + + def _resolve_tar_path_from_shard_id(self, shard_id: int) -> Path | None: + """Resolve shard_id to matching audio_{shard_id}.tar from configured tar shards.""" + expected_name = f"audio_{shard_id}.tar" + for tar_ref in self.tarred_audio_files: + if tar_ref.endswith(expected_name): + return self._materialize_tar_path(tar_ref) + return None + + def _resolve_audio_from_prompt(self, prompt) -> tuple[Path, Path | None]: + """Resolve any supported prompt structure to local audio path.""" + audio_reference = self._extract_audio_reference(prompt) + shard_id = self._extract_shard_id(prompt) + + # Explicit tar-member refs already encode the shard; use generic parsing path. 
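+        # For example, "audio_9.tar::member.wav" or "s3://bucket/audio_9.tar:member.wav"
+        # already name their shard, so the shard_id fast path below must not override them.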
+ is_explicit_tar_member_ref = ("::" in audio_reference) or (".tar:" in audio_reference) + if shard_id is not None and self.tarred_audio_files and not is_explicit_tar_member_ref: + shard_tar_path = self._resolve_tar_path_from_shard_id(shard_id) + if shard_tar_path is not None: + extracted_path = self._extract_audio_member(shard_tar_path, audio_reference) + return extracted_path, extracted_path + + return self._resolve_audio_path(audio_reference) + + @staticmethod + def _extract_audio_path_from_message(message: dict) -> str | None: + """Extract first audio path from a single OpenAI-style message dict.""" + if "audios" in message and isinstance(message["audios"], list) and len(message["audios"]) > 0: + first_audio = message["audios"][0] + if isinstance(first_audio, dict) and "path" in first_audio: + return first_audio["path"] + if "audio" in message and isinstance(message["audio"], dict) and "path" in message["audio"]: + return message["audio"]["path"] + return None + + def _extract_audio_reference(self, prompt) -> str: + """Extract audio reference string from supported prompt structures.""" + if isinstance(prompt, str): + return prompt + if isinstance(prompt, list) and len(prompt) > 0 and isinstance(prompt[0], str): + return prompt[0] + + # OpenAI-style prompt as list of messages + if isinstance(prompt, list): + for message in prompt: + if isinstance(message, dict): + path = self._extract_audio_path_from_message(message) + if path is not None: + return path + raise ValueError("Prompt list does not contain audio path in messages[*].audio(s).path") + + # Native NeMo manifest / full datapoint dict + if isinstance(prompt, dict): + if "messages" in prompt and isinstance(prompt["messages"], list): + return self._extract_audio_reference(prompt["messages"]) + + if "audio_filepath" in prompt and isinstance(prompt["audio_filepath"], str): + return prompt["audio_filepath"] + if "audio_filename" in prompt and isinstance(prompt["audio_filename"], str): + return prompt["audio_filename"] + if "audio_file" in prompt and isinstance(prompt["audio_file"], str): + return prompt["audio_file"] + if "context" in prompt and isinstance(prompt["context"], str): + return prompt["context"] + if "audio_path" in prompt: + audio_path = prompt["audio_path"] + if isinstance(audio_path, str): + return audio_path + if isinstance(audio_path, list) and len(audio_path) > 0 and isinstance(audio_path[0], str): + return audio_path[0] + + path = self._extract_audio_path_from_message(prompt) + if path is not None: + return path + + raise ValueError( + "Prompt dict does not contain supported audio key. Expected one of: " + "messages[*].audio(s).path, audio_filepath, audio_filename, audio_file, context, or audio_path." + ) + + raise TypeError(f"Unsupported prompt type for NemoASRModel: {type(prompt)}") + + async def generate_async(self, prompt, **kwargs): + """Transcribe audio file asynchronously. + + Args: + prompt: Audio reference as path string, OpenAI-style messages, or manifest-like dict. 
+ **kwargs: Generation parameters (most LLM parameters are ignored, use extra_body for ASR options) + + Returns: + dict: Result with 'generation' key containing transcription data + """ + # Validate and warn about unsupported LLM parameters + validate_unsupported_params(kwargs, "NemoASRModel") + + # Parse extra_body for ASR-specific options + extra_body = kwargs.get("extra_body", {}) + language = extra_body.get("language_code", self.language_code) + response_format = extra_body.get("response_format", self.response_format) + enable_timestamps = extra_body.get("enable_timestamps", self.enable_timestamps) + + audio_path, cleanup_path = self._resolve_audio_from_prompt(prompt) + try: + # Check if chunking is needed + chunk_duration = None + if self.enable_audio_chunking: + chunk_duration = await self._check_audio_duration(audio_path) + + # Prepare request + start_time = time.time() + + # Read audio file + async with aiofiles.open(audio_path, "rb") as f: + audio_bytes = await f.read() + + # Prepare multipart form data + files = {"file": (audio_path.name, audio_bytes, "audio/wav")} + + data = { + "model": self.model_name_or_path, + "language": language, + "response_format": response_format, + } + + # Add timestamp granularities if enabled + if enable_timestamps and response_format == "verbose_json": + data["timestamp_granularities"] = "word,segment" + + # Add chunking if needed + if chunk_duration is not None: + data["chunk_duration_sec"] = chunk_duration + + # Make request to server + response = await self._client.post(f"{self.base_url}/v1/audio/transcriptions", files=files, data=data) + response.raise_for_status() + + result_data = response.json() + generation_time = time.time() - start_time + + # Extract transcription text + pred_text = result_data["text"] + + # Build result in expected format + result: Dict[str, Any] = { + "pred_text": pred_text, + "generation_time": generation_time, + "audio_file": str(audio_path), + } + + # Add words if available + if "words" in result_data: + result["words"] = result_data["words"] + + # Add any additional metadata + if "language" in result_data: + result["language"] = result_data["language"] + if "duration" in result_data: + result["duration"] = result_data["duration"] + + return {"generation": result} + + except httpx.HTTPStatusError as e: + LOG.error(f"HTTP error during transcription: {e.response.status_code} - {e.response.text}") + raise RuntimeError(f"ASR server error: {e.response.text}") from e + except Exception as e: + LOG.error(f"ASR generation failed: {e}") + raise + finally: + if cleanup_path is not None and cleanup_path.exists(): + cleanup_path.unlink() + + async def _check_audio_duration(self, audio_path: Path) -> Optional[float]: + """Check audio duration and return chunk size if needed. 
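+
+        Duration is probed via soundfile metadata (sf.info), so the audio is not decoded.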
+ + Args: + audio_path: Path to audio file + + Returns: + Chunk duration in seconds if chunking is needed, None otherwise + """ + try: + import soundfile as sf + except ImportError: + LOG.warning("soundfile not available, skipping audio duration check") + return None + + try: + info = sf.info(str(audio_path)) + duration = info.duration + + if duration > self.chunk_audio_threshold_sec: + LOG.info( + f"Audio duration ({duration:.1f}s) exceeds threshold " + f"({self.chunk_audio_threshold_sec}s), enabling chunking" + ) + return self.chunk_audio_threshold_sec + + except Exception as e: + LOG.warning(f"Failed to check audio duration: {e}") + + return None + + async def __aenter__(self): + """Async context manager entry.""" + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Async context manager exit.""" + await self._client.aclose() + + def __del__(self): + """Clean up resources.""" + # Close SSH tunnel + if hasattr(self, "_tunnel") and self._tunnel: + try: + self._tunnel.stop() + except Exception: + pass # Ignore errors during cleanup diff --git a/nemo_skills/inference/server/serve_nemo_asr.py b/nemo_skills/inference/server/serve_nemo_asr.py new file mode 100644 index 0000000000..32255fd15c --- /dev/null +++ b/nemo_skills/inference/server/serve_nemo_asr.py @@ -0,0 +1,377 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""NeMo ASR server with OpenAI-compatible API. + +This server provides ASR inference using NeMo models (Canary, Parakeet, FastConformer) +with an OpenAI-compatible /v1/audio/transcriptions endpoint. +""" + +import argparse +import inspect +import logging +import os +import tempfile +import time +from typing import Any, Dict, List, Optional + +import uvicorn +from fastapi import FastAPI, File, Form, HTTPException, UploadFile +from fastapi.responses import JSONResponse + +LOG = logging.getLogger(__name__) + + +class NemoASRServer: + """NeMo ASR server handling model loading and inference.""" + + def __init__(self, model_path: str, num_gpus: int = 1): + """Initialize NeMo ASR server. + + Args: + model_path: Path to .nemo checkpoint or NGC model name (e.g., 'nvidia/canary-1b') + num_gpus: Number of GPUs to use (currently only 1 is supported) + """ + self.model_path = model_path + self.num_gpus = num_gpus + self.model = None + self._load_model() + + def _load_model(self): + """Load NeMo ASR model from checkpoint or NGC.""" + try: + import nemo.collections.asr as nemo_asr + except ImportError: + raise ImportError("NeMo toolkit is not installed. 
Please install it with: pip install nemo_toolkit[asr]") + + LOG.info(f"Loading NeMo ASR model from: {self.model_path}") + start_time = time.time() + + # Check if it's a local .nemo file or NGC model name + if os.path.exists(self.model_path) and self.model_path.endswith(".nemo"): + LOG.info("Loading from local .nemo checkpoint") + self.model = nemo_asr.models.ASRModel.restore_from(self.model_path) + else: + LOG.info("Loading from NGC or model name") + try: + self.model = nemo_asr.models.ASRModel.from_pretrained(self.model_path) + except Exception as e: + LOG.error(f"Failed to load model from NGC: {e}") + LOG.info("Attempting to load as local path...") + self.model = nemo_asr.models.ASRModel.restore_from(self.model_path) + + # Move model to GPU if available + if self.num_gpus > 0: + import torch + + if torch.cuda.is_available(): + self.model = self.model.cuda() + LOG.info("Model moved to GPU") + + self.model.eval() + load_time = time.time() - start_time + LOG.info(f"Model loaded successfully in {load_time:.2f}s") + + @staticmethod + def _extract_first_hypothesis(hypotheses): + """Extract first hypothesis from NeMo transcribe output with validation.""" + if len(hypotheses) == 0: + raise RuntimeError("No transcription returned from model") + first_entry = hypotheses[0] + + # Common NeMo shape: list[Hypothesis] + if hasattr(first_entry, "text"): + return first_entry + + # N-best shape: list[list[Hypothesis]] + if len(first_entry) == 0: + raise RuntimeError("Model returned empty transcription hypotheses") + if hasattr(first_entry[0], "text"): + return first_entry[0] + + raise RuntimeError(f"Unexpected hypothesis structure: {type(first_entry)}") + + def _transcribe_single(self, audio_paths: list[str], enable_timestamps: bool, language: Optional[str] = None): + """Transcribe helper with optional language passthrough when supported by model API.""" + transcribe_kwargs = { + "batch_size": 1, + "return_hypotheses": True, + "timestamps": enable_timestamps, + } + if language: + signature = inspect.signature(self.model.transcribe) + if "language_id" in signature.parameters: + transcribe_kwargs["language_id"] = language + else: + LOG.warning( + f"language='{language}' requested but model.transcribe does not support language_id parameter" + ) + return self.model.transcribe(audio_paths, **transcribe_kwargs) + + async def _transcribe_with_chunking( + self, + audio_path: str, + chunk_duration_sec: float, + enable_timestamps: bool = False, + language: Optional[str] = None, + ) -> tuple[str, float]: + """Transcribe long audio by chunking it into smaller segments. 
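+
+        Each chunk is written to a temporary wav file beside the input, transcribed, and
+        deleted; the optional ``language`` is forwarded to ``_transcribe_single``.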
+ + Args: + audio_path: Path to audio file + chunk_duration_sec: Duration of each chunk in seconds + enable_timestamps: Whether to enable timestamps + + Returns: + Tuple of (transcribed_text, total_inference_time) + """ + try: + import numpy as np + import soundfile as sf + except ImportError: + raise ImportError("soundfile and numpy are required for audio chunking") + + # Load audio file + audio_array, sampling_rate = sf.read(audio_path) + duration = len(audio_array) / sampling_rate + + LOG.info(f"Chunking audio ({duration:.1f}s) into segments of {chunk_duration_sec}s") + + # Calculate chunks + chunk_samples = int(chunk_duration_sec * sampling_rate) + num_chunks = int(np.ceil(len(audio_array) / chunk_samples)) + + chunks = [] + for i in range(num_chunks): + start = i * chunk_samples + end = min((i + 1) * chunk_samples, len(audio_array)) + chunk = audio_array[start:end] + + # Merge tiny trailing chunks + min_chunk_samples = int(0.5 * sampling_rate) # 0.5 second minimum + if len(chunk) < min_chunk_samples and chunks: + chunks[-1] = np.concatenate([chunks[-1], chunk]) + else: + chunks.append(chunk) + + LOG.info(f"Created {len(chunks)} audio chunks") + + # Transcribe each chunk + chunk_texts = [] + total_time = 0.0 + + for chunk_idx, audio_chunk in enumerate(chunks): + # Save chunk to temporary file + chunk_path = f"{audio_path}.chunk_{chunk_idx}.wav" + try: + sf.write(chunk_path, audio_chunk, sampling_rate) + + # Transcribe chunk + start_time = time.time() + hypotheses = self._transcribe_single([chunk_path], enable_timestamps, language) + chunk_time = time.time() - start_time + total_time += chunk_time + + hypothesis = self._extract_first_hypothesis(hypotheses) + text = hypothesis.text + chunk_texts.append(text.strip()) + LOG.debug(f"Chunk {chunk_idx + 1}/{len(chunks)}: {text[:50]}...") + + finally: + # Clean up chunk file + if os.path.exists(chunk_path): + os.unlink(chunk_path) + + # Concatenate all chunk transcriptions + full_text = " ".join(chunk_texts) + + return full_text, total_time + + async def transcribe( + self, + audio_file: UploadFile, + language: Optional[str] = None, + response_format: str = "json", + timestamp_granularities: Optional[List[str]] = None, + chunk_duration_sec: Optional[float] = None, + ) -> Dict[str, Any]: + """Transcribe audio file using NeMo ASR model. 
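+
+        The upload is spooled to a temporary wav file, transcribed (optionally in chunks
+        via ``chunk_duration_sec``), and the temporary file is removed afterwards.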
+ + Args: + audio_file: Audio file to transcribe + language: Language code (e.g., 'en', 'es') - optional + response_format: 'json' or 'verbose_json' + timestamp_granularities: List of granularities ['word', 'segment'] + chunk_duration_sec: If specified, chunk audio into segments of this duration + + Returns: + Transcription result in OpenAI-compatible format + """ + # Save uploaded file to temporary location + with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file: + content = await audio_file.read() + tmp_file.write(content) + tmp_path = tmp_file.name + + try: + # Determine if timestamps are needed + enable_timestamps = ( + timestamp_granularities is not None + and len(timestamp_granularities) > 0 + and response_format == "verbose_json" + ) + hypothesis = None + + # Handle chunking if requested + if chunk_duration_sec is not None and chunk_duration_sec > 0: + text, inference_time = await self._transcribe_with_chunking( + tmp_path, chunk_duration_sec, enable_timestamps, language + ) + if enable_timestamps: + LOG.warning("Word-level timestamps are not available when chunk_duration_sec is enabled.") + else: + # Transcribe using NeMo + start_time = time.time() + hypotheses = self._transcribe_single([tmp_path], enable_timestamps, language) + inference_time = time.time() - start_time + + # Extract transcription + hypothesis = self._extract_first_hypothesis(hypotheses) # [batch_idx][hypothesis_idx] + text = hypothesis.text + + # Build response based on format + result = {"text": text} + + if response_format == "verbose_json": + result["task"] = "transcribe" + result["duration"] = None # Could compute from audio file if needed + + # Add language if detected/specified + if language: + result["language"] = language + + # Add timestamps if requested + if enable_timestamps and hypothesis is not None and hasattr(hypothesis, "timestep"): + words = [] + + # Extract word-level timestamps + if "word" in timestamp_granularities: + word_timestamps = getattr(hypothesis.timestep, "word", []) + for word_info in word_timestamps: + words.append( + { + "word": word_info["word"], + "start": word_info["start_offset"], + "end": word_info["end_offset"], + } + ) + + if words: + result["words"] = words + + # Add inference metadata + result["inference_time"] = inference_time + + return result + + finally: + # Clean up temporary file + if os.path.exists(tmp_path): + os.unlink(tmp_path) + + +def create_app(model_path: str, num_gpus: int = 1) -> FastAPI: + """Create FastAPI application with NeMo ASR server. 
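+
+    A minimal client sketch (assuming the server listens on 127.0.0.1:5000):
+
+        import httpx
+
+        with open("sample.wav", "rb") as f:
+            resp = httpx.post(
+                "http://127.0.0.1:5000/v1/audio/transcriptions",
+                files={"file": ("sample.wav", f, "audio/wav")},
+                data={"response_format": "verbose_json"},
+            )
+        print(resp.json()["text"])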
+ + Args: + model_path: Path to model or NGC name + num_gpus: Number of GPUs to use + + Returns: + FastAPI application + """ + app = FastAPI( + title="NeMo ASR Server", description="OpenAI-compatible ASR server using NeMo models", version="1.0.0" + ) + + # Initialize server + server = NemoASRServer(model_path, num_gpus) + + @app.get("/health") + async def health(): + """Health check endpoint.""" + return {"status": "healthy", "model": model_path} + + @app.post("/v1/audio/transcriptions") + async def create_transcription( + file: UploadFile = File(..., description="Audio file to transcribe"), + model: str = Form(default="nemo-asr", description="Model to use (ignored, using server model)"), + language: Optional[str] = Form(default=None, description="Language code"), + response_format: str = Form(default="json", description="Response format: json or verbose_json"), + timestamp_granularities: Optional[str] = Form(default=None, description="Comma-separated list: word,segment"), + chunk_duration_sec: Optional[float] = Form( + default=None, description="If specified, chunk audio into segments of this duration (in seconds)" + ), + ): + """Transcribe audio file. + + OpenAI-compatible endpoint for audio transcription. + """ + try: + if model != "nemo-asr": + LOG.debug(f"Ignoring request model='{model}', serving with loaded model='{model_path}'") + # Parse timestamp granularities + granularities = None + if timestamp_granularities: + granularities = [g.strip() for g in timestamp_granularities.split(",")] + + result = await server.transcribe( + audio_file=file, + language=language, + response_format=response_format, + timestamp_granularities=granularities, + chunk_duration_sec=chunk_duration_sec, + ) + + return JSONResponse(content=result) + + except Exception as e: + LOG.error(f"Transcription failed: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=str(e)) + + return app + + +def main(): + """Main entry point for NeMo ASR server.""" + parser = argparse.ArgumentParser(description="Serve NeMo ASR model") + parser.add_argument("--model", required=True, help="Path to model or NGC model name") + parser.add_argument("--num_gpus", type=int, default=1, help="Number of GPUs to use") + parser.add_argument("--num_nodes", type=int, default=1, help="Number of nodes (not used, for compatibility)") + parser.add_argument("--port", type=int, default=5000, help="Server port") + parser.add_argument("--host", type=str, default="0.0.0.0", help="Server host") + args = parser.parse_args() + + # Setup logging + logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") + + # Create and run app + app = create_app(args.model, args.num_gpus) + + LOG.info(f"Starting NeMo ASR server on {args.host}:{args.port}") + uvicorn.run(app, host=args.host, port=args.port, log_level="info") + + +if __name__ == "__main__": + main() diff --git a/nemo_skills/pipeline/utils/server.py b/nemo_skills/pipeline/utils/server.py index 9fb84a680f..c791d146fd 100644 --- a/nemo_skills/pipeline/utils/server.py +++ b/nemo_skills/pipeline/utils/server.py @@ -27,6 +27,7 @@ class SupportedServersSelfHosted(str, Enum): vllm_multimodal = "vllm_multimodal" sglang = "sglang" megatron = "megatron" + nemo_asr = "nemo_asr" generic = "generic" @@ -36,6 +37,7 @@ class SupportedServers(str, Enum): vllm_multimodal = "vllm_multimodal" sglang = "sglang" megatron = "megatron" + nemo_asr = "nemo_asr" openai = "openai" azureopenai = "azureopenai" gemini = "gemini" @@ -125,9 +127,9 @@ def get_server_command( ): 
num_tasks = num_gpus - # check if the model path is mounted if not vllm, sglang, or trtllm; - # vllm, sglang, trtllm can also pass model name as "model_path" so we need special processing - if server_type not in ["vllm", "vllm_multimodal", "sglang", "trtllm", "generic"]: + # check if the model path is mounted if not vllm/sglang/trtllm-like servers; + # these server types can also pass model name as "model_path" so we need special processing. + if server_type not in ["vllm", "vllm_multimodal", "sglang", "trtllm", "nemo_asr", "generic"]: check_if_mounted(cluster_config, model_path) # the model path will be mounted, so generally it will start with / @@ -209,6 +211,17 @@ def get_server_command( num_tasks = 1 else: num_tasks = num_gpus + elif server_type == "nemo_asr": + server_entrypoint = server_entrypoint or "-m nemo_skills.inference.server.serve_nemo_asr" + server_start_cmd = ( + f"python3 {server_entrypoint} " + f" --model {model_path} " + f" --num_gpus {num_gpus} " + f" --num_nodes {num_nodes} " + f" --port {server_port} " + f" {server_args} " + ) + num_tasks = 1 elif server_type == "generic": if not server_entrypoint: raise ValueError("For 'generic' server type, 'server_entrypoint' must be specified.") diff --git a/tests/test_nemo_asr_support.py b/tests/test_nemo_asr_support.py new file mode 100644 index 0000000000..00e6607a16 --- /dev/null +++ b/tests/test_nemo_asr_support.py @@ -0,0 +1,253 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
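+
+"""Unit tests for the NemoASRModel client helpers and NemoASRServer utilities.
+
+These run fully offline: the client never contacts a server during path resolution,
+and the server helpers are exercised without loading a NeMo model."""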
+ +import asyncio +import io +import tarfile +from types import SimpleNamespace +from unittest.mock import patch + +from nemo_skills.inference.model.nemo_asr import NemoASRModel +from nemo_skills.inference.server.serve_nemo_asr import NemoASRServer +from nemo_skills.pipeline.utils.server import SupportedServers + + +def _close_model(model: NemoASRModel): + asyncio.run(model._client.aclose()) + + +def test_nemo_asr_resolve_audio_path_from_tarred_member(tmp_path): + member_name = "sample.wav" + member_bytes = b"fake wav content" + tar_path = tmp_path / "shard_0.tar" + + with tarfile.open(tar_path, "w") as tar: + info = tarfile.TarInfo(name=member_name) + info.size = len(member_bytes) + tar.addfile(info, fileobj=io.BytesIO(member_bytes)) + + model = NemoASRModel(data_dir=str(tmp_path), tarred_audio_filepaths="shard_0.tar") + extracted_path, cleanup_path = model._resolve_audio_path(member_name) + try: + assert extracted_path.is_file() + assert cleanup_path == extracted_path + assert extracted_path.read_bytes() == member_bytes + finally: + extracted_path.unlink(missing_ok=True) + _close_model(model) + + +def test_nemo_asr_resolve_audio_path_from_explicit_tar_prompt(tmp_path): + member_name = "nested/sample.wav" + member_bytes = b"fake wav content 2" + tar_path = tmp_path / "shard_1.tar" + + with tarfile.open(tar_path, "w") as tar: + info = tarfile.TarInfo(name=member_name) + info.size = len(member_bytes) + tar.addfile(info, fileobj=io.BytesIO(member_bytes)) + + model = NemoASRModel(data_dir=str(tmp_path)) + extracted_path, cleanup_path = model._resolve_audio_path(f"shard_1.tar::{member_name}") + try: + assert extracted_path.is_file() + assert cleanup_path == extracted_path + assert extracted_path.read_bytes() == member_bytes + finally: + extracted_path.unlink(missing_ok=True) + _close_model(model) + + +def test_nemo_asr_server_extract_first_hypothesis_validation(): + hypothesis = SimpleNamespace(text="hello") + assert NemoASRServer._extract_first_hypothesis([hypothesis]) is hypothesis + assert NemoASRServer._extract_first_hypothesis([[hypothesis]]) is hypothesis + + try: + NemoASRServer._extract_first_hypothesis([]) + raise AssertionError("Expected RuntimeError for empty hypotheses") + except RuntimeError as e: + assert "No transcription returned" in str(e) + + try: + NemoASRServer._extract_first_hypothesis([[]]) + raise AssertionError("Expected RuntimeError for empty inner hypotheses") + except RuntimeError as e: + assert "empty transcription hypotheses" in str(e) + + +def test_nemo_asr_server_transcribe_single_passes_language_id_when_supported(): + class _FakeModel: + def transcribe(self, paths, batch_size, return_hypotheses, timestamps, language_id=None): + assert paths == ["a.wav"] + assert batch_size == 1 + assert return_hypotheses is True + assert timestamps is False + assert language_id == "el" + return [[SimpleNamespace(text="ok")]] + + server = NemoASRServer.__new__(NemoASRServer) + server.model = _FakeModel() + output = server._transcribe_single(["a.wav"], enable_timestamps=False, language="el") + assert output[0][0].text == "ok" + + +def test_nemo_asr_server_transcribe_single_without_language_id_param(): + class _FakeModel: + def transcribe(self, paths, batch_size, return_hypotheses, timestamps): + assert paths == ["a.wav"] + return [SimpleNamespace(text="ok")] + + server = NemoASRServer.__new__(NemoASRServer) + server.model = _FakeModel() + output = server._transcribe_single(["a.wav"], enable_timestamps=True, language="el") + assert output[0].text == "ok" + + +def 
test_nemo_asr_resolve_audio_path_from_s3_tar_prompt(tmp_path): + member_name = "sample.wav" + member_bytes = b"fake wav content s3" + local_tar = tmp_path / "cached_s3_shard.tar" + s3_tar_uri = "s3://my-bucket/audio_0.tar" + + with tarfile.open(local_tar, "w") as tar: + info = tarfile.TarInfo(name=member_name) + info.size = len(member_bytes) + tar.addfile(info, fileobj=io.BytesIO(member_bytes)) + + model = NemoASRModel(data_dir=str(tmp_path)) + try: + with ( + patch.object(NemoASRModel, "_is_datastore_path", side_effect=lambda p: p.startswith("s3://")), + patch.object(NemoASRModel, "_download_datastore_object", return_value=local_tar), + ): + extracted_path, cleanup_path = model._resolve_audio_path(f"{s3_tar_uri}:{member_name}") + assert extracted_path.is_file() + assert cleanup_path == extracted_path + assert extracted_path.read_bytes() == member_bytes + extracted_path.unlink(missing_ok=True) + finally: + _close_model(model) + + +def test_nemo_asr_extract_audio_reference_from_openai_messages(tmp_path): + model = NemoASRModel(data_dir=str(tmp_path)) + try: + prompt = [ + {"role": "system", "content": "Answer the questions."}, + { + "role": "user", + "content": "Transcribe", + "audio": {"path": "/data/example.wav", "duration": 1.0}, + "audios": [{"path": "/data/example.wav", "duration": 1.0}], + }, + ] + assert model._extract_audio_reference(prompt) == "/data/example.wav" + finally: + _close_model(model) + + +def test_nemo_asr_extract_audio_reference_from_native_manifest_dict(tmp_path): + model = NemoASRModel(data_dir=str(tmp_path)) + try: + prompt = { + "audio_filepath": "/data/native.wav", + "duration": 1.23, + "text": "hello", + } + assert model._extract_audio_reference(prompt) == "/data/native.wav" + + prompt_audio_path = { + "audio_path": ["/data/list_based.wav"], + "text": "hello", + } + assert model._extract_audio_reference(prompt_audio_path) == "/data/list_based.wav" + + prompt_audio_filename = { + "audio_filename": "relative_name.wav", + "duration": 1.0, + } + assert model._extract_audio_reference(prompt_audio_filename) == "relative_name.wav" + + prompt_context = { + "context": "ais://bucket/path/audio.wav", + "duration": 1.0, + } + assert model._extract_audio_reference(prompt_context) == "ais://bucket/path/audio.wav" + finally: + _close_model(model) + + +def test_nemo_asr_resolve_audio_path_from_shard_id_optimized(tmp_path): + member_name = "manifest_member.flac" + member_bytes = b"fake wav content shard" + tar_path = tmp_path / "audio_9.tar" + + with tarfile.open(tar_path, "w") as tar: + info = tarfile.TarInfo(name=member_name) + info.size = len(member_bytes) + tar.addfile(info, fileobj=io.BytesIO(member_bytes)) + + model = NemoASRModel(data_dir=str(tmp_path), tarred_audio_filepaths="audio_*.tar") + try: + prompt = {"audio_filepath": member_name, "shard_id": 9} + audio_reference = model._extract_audio_reference(prompt) + shard_tar_path = model._resolve_tar_path_from_shard_id(model._extract_shard_id(prompt)) + extracted_path = model._extract_audio_member(shard_tar_path, audio_reference) + assert extracted_path.is_file() + assert extracted_path.read_bytes() == member_bytes + extracted_path.unlink(missing_ok=True) + finally: + _close_model(model) + + +def test_nemo_asr_shard_id_does_not_override_explicit_tar_member_prompt(tmp_path): + member_name = "member_in_tar.flac" + member_bytes = b"member bytes" + local_tar = tmp_path / "audio_9.tar" + s3_tar_uri = "s3://my-bucket/audio_9.tar" + + with tarfile.open(local_tar, "w") as tar: + info = tarfile.TarInfo(name=member_name) + info.size 
= len(member_bytes) + tar.addfile(info, fileobj=io.BytesIO(member_bytes)) + + model = NemoASRModel(data_dir=str(tmp_path), tarred_audio_filepaths=str(local_tar)) + try: + prompt = {"audio_filepath": f"{s3_tar_uri}:{member_name}", "shard_id": 9} + with ( + patch.object(NemoASRModel, "_is_datastore_path", side_effect=lambda p: p.startswith("s3://")), + patch.object(NemoASRModel, "_download_datastore_object", return_value=local_tar), + ): + extracted_path, cleanup_path = model._resolve_audio_from_prompt(prompt) + assert extracted_path.is_file() + assert cleanup_path == extracted_path + assert extracted_path.read_bytes() == member_bytes + extracted_path.unlink(missing_ok=True) + finally: + _close_model(model) + + +def test_supported_servers_includes_nemo_asr(): + assert SupportedServers.nemo_asr.value == "nemo_asr" + + +def test_nemo_asr_base_url_with_scheme_port_and_path(tmp_path): + model = NemoASRModel(base_url="https://host.example:5000/v1", data_dir=str(tmp_path)) + try: + assert model.server_host == "host.example" + assert model.server_port == "5000" + assert model.base_url == "https://host.example:5000/v1" + finally: + _close_model(model)
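+
+
+def test_nemo_asr_extract_shard_id_coercion(tmp_path):
+    # Exercises the shard_id coercion helper directly, mirroring the dict prompts above:
+    # ints pass through, digit strings are converted, anything else yields None.
+    model = NemoASRModel(data_dir=str(tmp_path))
+    try:
+        assert model._extract_shard_id({"shard_id": 3}) == 3
+        assert model._extract_shard_id({"shard_id": "7"}) == 7
+        assert model._extract_shard_id({"shard_id": "not-a-number"}) is None
+        assert model._extract_shard_id("plain_string_prompt") is None
+    finally:
+        _close_model(model)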