Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 23 additions & 17 deletions recipes/multimodal/server/backends/magpie_tts_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import shutil
import tempfile
import time
from dataclasses import dataclass
from dataclasses import dataclass, fields
from typing import Any, Dict, List, Optional, Set

import soundfile as sf
Expand All @@ -22,12 +22,13 @@
@dataclass
class MagpieTTSConfig(BackendConfig):
codec_model_path: Optional[str] = None
top_k: int = 80
temperature: float = 0.6
max_decoder_steps: Optional[int] = None
temperature: Optional[float] = None
top_k: Optional[int] = None
cfg_scale: Optional[float] = None
use_cfg: bool = True
cfg_scale: float = 2.5
max_decoder_steps: int = 440
use_local_transformer: bool = False
use_local_transformer: bool = True
apply_attention_prior: bool = True
output_sample_rate: int = 22050
# Checkpoint loading options (alternative to model_path .nemo file)
hparams_file: Optional[str] = None
Expand Down Expand Up @@ -140,6 +141,7 @@ def _load_fsspec_patched(path: str, map_location: str = None, **kwargs):
except Exception:
pass

from nemo.collections.tts.models.magpietts import ModelInferenceParameters
from nemo.collections.tts.modules.magpietts_inference.inference import InferenceConfig, MagpieInferenceRunner
from nemo.collections.tts.modules.magpietts_inference.utils import ModelLoadConfig, load_magpie_model

Expand All @@ -166,19 +168,23 @@ def _load_fsspec_patched(path: str, map_location: str = None, **kwargs):
)
self._model, self._checkpoint_name = load_magpie_model(cfg, device=self.config.device)

self._runner = MagpieInferenceRunner(
self._model,
InferenceConfig(
temperature=self.tts_config.temperature,
topk=self.tts_config.top_k,
max_decoder_steps=self.tts_config.max_decoder_steps,
use_cfg=self.tts_config.use_cfg,
cfg_scale=self.tts_config.cfg_scale,
use_local_transformer=self.tts_config.use_local_transformer,
batch_size=16,
),
# Merge args from MagpieTTSConfig into InferenceConfig
model_inference_parameters = {}
for field in fields(ModelInferenceParameters):
field = field.name
        override_arg = getattr(self.tts_config, field)  # NOTE(review): was `self.tts` — every other access in this hunk uses `self.tts_config`; `self.tts` would raise AttributeError
if override_arg is not None:
model_inference_parameters[field] = override_arg
inference_config = InferenceConfig(
model_inference_parameters=ModelInferenceParameters.from_dict(model_inference_parameters),
batch_size=16,
use_cfg=self.tts_config.use_cfg,
use_local_transformer=self.tts_config.use_local_transformer,
apply_attention_prior=self.tts_config.apply_attention_prior,
)

self._runner = MagpieInferenceRunner(self._model, inference_config)

self._temp_dir = tempfile.mkdtemp(prefix="magpie_tts_")
self.tts_config.output_sample_rate = self._model.sample_rate
self._is_loaded = True
Expand Down
Loading