Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions instruct_template.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@
]
```

* Note: Change the user prompt from "Please convert this text to speech" to "请将这段文字转换为语音" when synthesizing Chinese text, to better align with the base model's training data.

### Instruct TTS

```
Expand Down
11 changes: 7 additions & 4 deletions mimo_audio_train/models/mimo_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
)
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
from peft import get_peft_model
from mimo_audio_train.arguments import CustomArguments

from .src_mimo_audio.process_speechdata import InputSegment, StreamingInputSegment
from .src_mimo_audio.mimo_audio_tokenizer import MiMoAudioTokenizer
Expand Down Expand Up @@ -75,6 +76,7 @@ def __init__(
self.eot_idx = self.tokenizer.convert_tokens_to_ids("<|eot|>")
self.im_start_idx = self.tokenizer.convert_tokens_to_ids("<|im_start|>")
self.im_end_idx = self.tokenizer.convert_tokens_to_ids("<|im_end|>")
self.speech_loss_weights = CustomArguments.speech_loss_weights

model_args = MiMoAudioArguments(
model_name_or_path=self.path,
Expand All @@ -84,6 +86,7 @@ def __init__(
sostm_idx=self.sostm_idx,
eostm_idx=self.eostm_idx,
eot_idx=self.eot_idx,
speech_loss_weights=self.speech_loss_weights,
)

start_loading_time = time.monotonic()
Expand Down Expand Up @@ -510,9 +513,9 @@ def get_tts_sft_prompt(
else:
language = detect_language(input)
if language == "zh":
template = "请将这段文字转换为语音"
template = "请将这段文字转换为语音"
else:
template = "Please convert this text to speech."
template = "Please convert this text to speech"

text = self.preprocess_input(input)
if instruct is None:
Expand Down Expand Up @@ -1397,7 +1400,7 @@ def forward(
add_history=False,
task_name=None,
):

task_sampler = self.get_task_sampler(task_name)

generation_kwargs = self.generate_kwargs.copy()
Expand Down Expand Up @@ -1485,7 +1488,7 @@ def forward(
return detokenized_text
else:
return wav_concat

def asr_sft(self, audio, lang='zh'):
stopping_criteria = [
MiMoStopper(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -486,7 +486,6 @@ def forward(
use_cache=True,
return_dict=True,
cache_position=cache_position,
is_causal=True,
)
hidden_states = outputs.last_hidden_state # [B, new_T_group, hidden_size]

Expand Down
4 changes: 3 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,6 @@ hyperpyyaml==1.2.2
loguru==0.7.3
sox==1.5.0
s3prl==0.4.18
timm
timm
peft
deepspeed