diff --git a/nemo_skills/inference/model/vllm_multimodal.py b/nemo_skills/inference/model/vllm_multimodal.py
index c2285b0ccc..0348c6ec22 100644
--- a/nemo_skills/inference/model/vllm_multimodal.py
+++ b/nemo_skills/inference/model/vllm_multimodal.py
@@ -21,6 +21,7 @@
 import base64
 import copy
+import io
 import json
 import logging
 import os
@@ -100,6 +101,14 @@ def _parse_chat_completion_response(self, response, include_response: bool = Fal
             except json.JSONDecodeError:
                 LOG.warning("Failed to parse debug_info JSON from content")
 
+        # Save codec data if present in debug_info (for FCD scoring)
+        if "debug_info" in result and result["debug_info"].get("codec_data"):
+            codec_codes_path = self._save_codec_data(result["debug_info"]["codec_data"], response.id)
+            if codec_codes_path:
+                result["debug_info"]["codec_codes_path"] = codec_codes_path
+            # Remove base64 data to avoid duplication in output
+            del result["debug_info"]["codec_data"]
+
         choice = response.choices[0]
         if hasattr(choice.message, "audio") and choice.message.audio:
             audio_result = self._process_audio_response(choice.message.audio, response.id)
@@ -404,3 +413,23 @@ def _build_chat_request_params(
         messages = [self.content_text_to_list(copy.deepcopy(msg)) for msg in messages]
         messages = self._preprocess_messages_for_model(messages)
         return super()._build_chat_request_params(messages=messages, **kwargs)
+
+    def _save_codec_data(self, codec_base64: str, response_id: str) -> str:
+        """Save codec data (.pt file) to disk and return the path."""
+        if not self.output_audio_dir:
+            return None
+
+        try:
+            import torch
+
+            codec_bytes = base64.b64decode(codec_base64)
+            buf = io.BytesIO(codec_bytes)
+            codes = torch.load(buf, map_location="cpu")
+
+            filename = f"{response_id}.pt"
+            filepath = os.path.join(self.output_audio_dir, filename)
+            torch.save(codes, filepath)
+            return filepath
+        except Exception as e:
+            LOG.warning(f"Failed to save codec data: {e}")
+            return None
diff --git a/recipes/multimodal/server/backends/magpie_tts_backend.py b/recipes/multimodal/server/backends/magpie_tts_backend.py
index 4d5b910d64..a6877bd363 100644
--- a/recipes/multimodal/server/backends/magpie_tts_backend.py
+++ b/recipes/multimodal/server/backends/magpie_tts_backend.py
@@ -29,6 +29,8 @@ class MagpieTTSConfig(BackendConfig):
     max_decoder_steps: int = 440
     use_local_transformer: bool = False
     output_sample_rate: int = 22050
+    save_codes: bool = False  # Save codec codes for FCD scoring
+    longform_mode: str = "auto"  # "auto" | "always" | "never" - longform inference mode
     # Checkpoint loading options (alternative to model_path .nemo file)
     hparams_file: Optional[str] = None
     checkpoint_file: Optional[str] = None
@@ -52,6 +54,8 @@ def from_dict(cls, d: Dict[str, Any]) -> "MagpieTTSConfig":
             "max_decoder_steps",
             "use_local_transformer",
             "output_sample_rate",
+            "save_codes",
+            "longform_mode",
             "hparams_file",
             "checkpoint_file",
             "legacy_codebooks",
@@ -164,25 +168,28 @@ def _load_fsspec_patched(path: str, map_location: str = None, **kwargs):
         )
         self._model, self._checkpoint_name = load_magpie_model(cfg, device=self.config.device)
 
-        self._runner = MagpieInferenceRunner(
-            self._model,
-            InferenceConfig(
-                temperature=self.tts_config.temperature,
-                topk=self.tts_config.top_k,
-                max_decoder_steps=self.tts_config.max_decoder_steps,
-                use_cfg=self.tts_config.use_cfg,
-                cfg_scale=self.tts_config.cfg_scale,
-                use_local_transformer=self.tts_config.use_local_transformer,
-                batch_size=16,
-            ),
+        # Build InferenceConfig with nested ModelInferenceParameters
+        from nemo.collections.tts.models.magpietts import ModelInferenceParameters
+
+        model_params = ModelInferenceParameters(
+            temperature=self.tts_config.temperature,
+            topk=self.tts_config.top_k,
+            cfg_scale=self.tts_config.cfg_scale,
+            max_decoder_steps=self.tts_config.max_decoder_steps,
+        )
+        inference_config = InferenceConfig(
+            batch_size=16,
+            use_cfg=self.tts_config.use_cfg,
+            use_local_transformer=self.tts_config.use_local_transformer,
+            model_inference_parameters=model_params,
+            longform_mode=self.tts_config.longform_mode,
         )
+        self._runner = MagpieInferenceRunner(self._model, inference_config)
+
         self._temp_dir = tempfile.mkdtemp(prefix="magpie_tts_")
         self.tts_config.output_sample_rate = self._model.sample_rate
         self._is_loaded = True
-        print(
-            f"[MagpieTTSBackend] Loaded: {self._checkpoint_name}, sr={self._model.sample_rate}, cfg={self.tts_config.use_cfg}"
-        )
 
     def _extract_json(self, text: str) -> dict:
         """Extract JSON object from text, skipping non-JSON parts."""
@@ -262,8 +269,9 @@ def generate(self, requests: List[GenerationRequest]) -> List[GenerationResult]:
         from nemo.collections.tts.modules.magpietts_inference.evaluate_generated_audio import load_evalset_config
 
         dataset = self._runner.create_dataset(load_evalset_config(config_path))
-        rtf_list, _ = self._runner.run_inference_on_dataset(
-            dataset, output_dir, save_cross_attention_maps=False, save_context_audio=False
+        rtf_list, *_ = self._runner.run_inference_on_dataset(
+            dataset, output_dir, save_cross_attention_maps=False, save_context_audio=False,
+            save_predicted_codes=self.tts_config.save_codes,
         )
 
         gen_time = time.time() - start_time
@@ -283,6 +291,31 @@ def generate(self, requests: List[GenerationRequest]) -> List[GenerationResult]:
                 sf.write(buf, audio, sr, format="WAV")
                 buf.seek(0)
                 dur = len(audio) / sr
+
+                debug_info = {
+                    "checkpoint": self._checkpoint_name,
+                    "audio_duration_sec": dur,
+                    "rtf": gen_time / len(requests) / dur if dur else 0,
+                    "config": {
+                        "temp": self.tts_config.temperature,
+                        "top_k": self.tts_config.top_k,
+                        "cfg": self.tts_config.use_cfg,
+                        "cfg_scale": self.tts_config.cfg_scale,
+                    },
+                    "batch_metrics": batch_metrics,
+                }
+
+                # Include codec data if save_codes is enabled (for FCD scoring)
+                if self.tts_config.save_codes:
+                    codes_path = os.path.join(output_dir, f"predicted_codes_{i}.pt")
+                    if os.path.exists(codes_path):
+                        import base64
+                        import torch
+                        codes_buf = io.BytesIO()
+                        torch.save(torch.load(codes_path, map_location="cpu"), codes_buf)
+                        codes_buf.seek(0)
+                        debug_info["codec_data"] = base64.b64encode(codes_buf.read()).decode("utf-8")
+
                 results.append(
                     GenerationResult(
                         text=parsed[i].get("text", ""),
@@ -291,18 +324,7 @@ def generate(self, requests: List[GenerationRequest]) -> List[GenerationResult]:
                         audio_format="wav",
                         request_id=req.request_id,
                         generation_time_ms=gen_time * 1000 / len(requests),
-                        debug_info={
-                            "checkpoint": self._checkpoint_name,
-                            "audio_duration_sec": dur,
-                            "rtf": gen_time / len(requests) / dur if dur else 0,
-                            "config": {
-                                "temp": self.tts_config.temperature,
-                                "top_k": self.tts_config.top_k,
-                                "cfg": self.tts_config.use_cfg,
-                                "cfg_scale": self.tts_config.cfg_scale,
-                            },
-                            "batch_metrics": batch_metrics,
-                        },
+                        debug_info=debug_info,
                    )
                )
            else:
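Note for reviewers: the two diffs together implement a codec round trip. The backend serializes the predicted codec codes with torch.save, base64-encodes them into debug_info["codec_data"], and the client's _save_codec_data decodes the payload back into a .pt file for FCD scoring. Below is a minimal, self-contained sketch of that encode/decode pair; the function names and tensor shape are illustrative only, not part of either module.

import base64
import io

import torch


def encode_codes(codes: torch.Tensor) -> str:
    # Backend side: serialize a codec-codes tensor to a base64 string.
    buf = io.BytesIO()
    torch.save(codes, buf)
    buf.seek(0)
    return base64.b64encode(buf.read()).decode("utf-8")


def decode_codes(codec_base64: str) -> torch.Tensor:
    # Client side: recover the tensor from the base64 payload.
    codec_bytes = base64.b64decode(codec_base64)
    return torch.load(io.BytesIO(codec_bytes), map_location="cpu")


# Round trip check with a dummy tensor, e.g. 8 codebooks x 250 frames.
codes = torch.randint(0, 1024, (8, 250))
assert torch.equal(decode_codes(encode_codes(codes)), codes)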