29 changes: 29 additions & 0 deletions nemo_skills/inference/model/vllm_multimodal.py
@@ -21,6 +21,7 @@

import base64
import copy
import io
import json
import logging
import os
@@ -100,6 +101,14 @@ def _parse_chat_completion_response(self, response, include_response: bool = False
except json.JSONDecodeError:
LOG.warning("Failed to parse debug_info JSON from content")

# Save codec data if present in debug_info (for FCD scoring)
if "debug_info" in result and result["debug_info"].get("codec_data"):
codec_codes_path = self._save_codec_data(result["debug_info"]["codec_data"], response.id)
if codec_codes_path:
result["debug_info"]["codec_codes_path"] = codec_codes_path
# Remove base64 data to avoid duplication in output
del result["debug_info"]["codec_data"]

choice = response.choices[0]
if hasattr(choice.message, "audio") and choice.message.audio:
audio_result = self._process_audio_response(choice.message.audio, response.id)
@@ -404,3 +413,23 @@ def _build_chat_request_params(
messages = [self.content_text_to_list(copy.deepcopy(msg)) for msg in messages]
messages = self._preprocess_messages_for_model(messages)
return super()._build_chat_request_params(messages=messages, **kwargs)

def _save_codec_data(self, codec_base64: str, response_id: str) -> str:
"""Save codec data (.pt file) to disk and return the path."""
if not self.output_audio_dir:
return None

try:
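            # Lazy import: torch is only required when codec data is actually present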
import torch

codec_bytes = base64.b64decode(codec_base64)
buf = io.BytesIO(codec_bytes)
codes = torch.load(buf, map_location="cpu")

filename = f"{response_id}.pt"
filepath = os.path.join(self.output_audio_dir, filename)
torch.save(codes, filepath)
return filepath
except Exception as e:
LOG.warning(f"Failed to save codec data: {e}")
return None
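
For context, a downstream FCD scorer would read the codes back from the path recorded in debug_info. A minimal sketch, assuming result is the dict built by _parse_chat_completion_response above; the scoring step itself is illustrative:

import torch

codes_path = result["debug_info"].get("codec_codes_path")
if codes_path:
    # Tensor(s) exactly as persisted by _save_codec_data above
    codes = torch.load(codes_path, map_location="cpu")
    # ...pass codes to the FCD metric of choice
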
78 changes: 50 additions & 28 deletions recipes/multimodal/server/backends/magpie_tts_backend.py
@@ -29,6 +29,8 @@ class MagpieTTSConfig(BackendConfig):
max_decoder_steps: int = 440
use_local_transformer: bool = False
output_sample_rate: int = 22050
save_codes: bool = False # Save codec codes for FCD scoring
longform_mode: str = "auto" # "auto" | "always" | "never" - longform inference mode
# Checkpoint loading options (alternative to model_path .nemo file)
hparams_file: Optional[str] = None
checkpoint_file: Optional[str] = None
@@ -52,6 +54,8 @@ def from_dict(cls, d: Dict[str, Any]) -> "MagpieTTSConfig":
"max_decoder_steps",
"use_local_transformer",
"output_sample_rate",
"save_codes",
"longform_mode",
"hparams_file",
"checkpoint_file",
"legacy_codebooks",
@@ -164,25 +168,28 @@ def _load_fsspec_patched(path: str, map_location: str = None, **kwargs):
)
self._model, self._checkpoint_name = load_magpie_model(cfg, device=self.config.device)

self._runner = MagpieInferenceRunner(
self._model,
InferenceConfig(
temperature=self.tts_config.temperature,
topk=self.tts_config.top_k,
max_decoder_steps=self.tts_config.max_decoder_steps,
use_cfg=self.tts_config.use_cfg,
cfg_scale=self.tts_config.cfg_scale,
use_local_transformer=self.tts_config.use_local_transformer,
batch_size=16,
),
# Build InferenceConfig with nested ModelInferenceParameters
from nemo.collections.tts.models.magpietts import ModelInferenceParameters

model_params = ModelInferenceParameters(
temperature=self.tts_config.temperature,
topk=self.tts_config.top_k,
cfg_scale=self.tts_config.cfg_scale,
max_decoder_steps=self.tts_config.max_decoder_steps,
)
inference_config = InferenceConfig(
batch_size=16,
use_cfg=self.tts_config.use_cfg,
use_local_transformer=self.tts_config.use_local_transformer,
model_inference_parameters=model_params,
longform_mode=self.tts_config.longform_mode,
)

self._runner = MagpieInferenceRunner(self._model, inference_config)

self._temp_dir = tempfile.mkdtemp(prefix="magpie_tts_")
self.tts_config.output_sample_rate = self._model.sample_rate
self._is_loaded = True
print(
f"[MagpieTTSBackend] Loaded: {self._checkpoint_name}, sr={self._model.sample_rate}, cfg={self.tts_config.use_cfg}"
)

def _extract_json(self, text: str) -> dict:
"""Extract JSON object from text, skipping non-JSON parts."""
@@ -262,8 +269,9 @@ def generate(self, requests: List[GenerationRequest]) -> List[GenerationResult]:
from nemo.collections.tts.modules.magpietts_inference.evaluate_generated_audio import load_evalset_config

dataset = self._runner.create_dataset(load_evalset_config(config_path))
rtf_list, _ = self._runner.run_inference_on_dataset(
dataset, output_dir, save_cross_attention_maps=False, save_context_audio=False
rtf_list, *_ = self._runner.run_inference_on_dataset(
dataset, output_dir, save_cross_attention_maps=False, save_context_audio=False,
save_predicted_codes=self.tts_config.save_codes,
)

gen_time = time.time() - start_time
@@ -283,6 +291,31 @@ def generate(self, requests: List[GenerationRequest]) -> List[GenerationResult]:
sf.write(buf, audio, sr, format="WAV")
buf.seek(0)
dur = len(audio) / sr

debug_info = {
"checkpoint": self._checkpoint_name,
"audio_duration_sec": dur,
"rtf": gen_time / len(requests) / dur if dur else 0,
"config": {
"temp": self.tts_config.temperature,
"top_k": self.tts_config.top_k,
"cfg": self.tts_config.use_cfg,
"cfg_scale": self.tts_config.cfg_scale,
},
"batch_metrics": batch_metrics,
}

# Include codec data if save_codes is enabled (for FCD scoring)
if self.tts_config.save_codes:
codes_path = os.path.join(output_dir, f"predicted_codes_{i}.pt")
if os.path.exists(codes_path):
import base64
import torch
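                    # Round-trip through torch.load/torch.save with map_location="cpu" so any
                    # GPU-saved tensors are re-serialized as CPU tensors before base64 encoding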
codes_buf = io.BytesIO()
torch.save(torch.load(codes_path, map_location="cpu"), codes_buf)
codes_buf.seek(0)
debug_info["codec_data"] = base64.b64encode(codes_buf.read()).decode("utf-8")

results.append(
GenerationResult(
text=parsed[i].get("text", ""),
@@ -291,18 +324,7 @@ def generate(self, requests: List[GenerationRequest]) -> List[GenerationResult]:
audio_format="wav",
request_id=req.request_id,
generation_time_ms=gen_time * 1000 / len(requests),
debug_info={
"checkpoint": self._checkpoint_name,
"audio_duration_sec": dur,
"rtf": gen_time / len(requests) / dur if dur else 0,
"config": {
"temp": self.tts_config.temperature,
"top_k": self.tts_config.top_k,
"cfg": self.tts_config.use_cfg,
"cfg_scale": self.tts_config.cfg_scale,
},
"batch_metrics": batch_metrics,
},
debug_info=debug_info,
)
)
else:
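
For reference, the two new config fields flow through the existing from_dict key list. A minimal sketch of enabling them; all keys other than save_codes and longform_mode are illustrative:

cfg = MagpieTTSConfig.from_dict({
    # ...existing backend keys (model_path, temperature, ...) unchanged...
    "save_codes": True,        # include codec_data in debug_info for FCD scoring
    "longform_mode": "never",  # one of "auto" | "always" | "never"
})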