diff --git a/app.py b/app.py
index c92f4e33..32b9d70e 100644
--- a/app.py
+++ b/app.py
@@ -1,8 +1,13 @@
 import argparse
 import tempfile
 import time
+import queue
+import threading
+import os
+import io
+import base64
 from pathlib import Path
-from typing import Optional, Tuple
+from typing import Optional, Tuple, List, Generator, Dict, Any
 
 import gradio as gr
 import numpy as np
@@ -46,6 +51,32 @@
     raise
 
 
+def create_audio_html(audio_data, sample_rate=44100):
+    """
+    Create an HTML audio element with the audio data
+    """
+    # Convert audio to 16-bit PCM WAV
+    audio_int16 = (audio_data * 32767).astype(np.int16)
+
+    # Create in-memory WAV file
+    wav_bytes = io.BytesIO()
+    with sf.SoundFile(wav_bytes, mode='w', samplerate=sample_rate, channels=1, format='WAV', subtype='PCM_16') as f:
+        f.write(audio_int16)
+
+    # Encode as base64
+    wav_bytes.seek(0)
+    base64_audio = base64.b64encode(wav_bytes.read()).decode('utf-8')
+
+    # Create HTML audio element with autoplay
+    audio_html = f"""
+    <audio controls autoplay>
+        <source src="data:audio/wav;base64,{base64_audio}" type="audio/wav">
+        Your browser does not support the audio element.
+    </audio>
+    """
+    return audio_html
+
+
 def run_inference(
     text_input: str,
     audio_prompt_input: Optional[Tuple[int, np.ndarray]],
@@ -55,6 +86,11 @@ def run_inference(
     top_p: float,
     cfg_filter_top_k: int,
     speed_factor: float,
+    use_streaming: bool = False,
+    autoplay_chunks: bool = False,
+    chunk_size: int = 100,
+    overlap: int = 10,
+    progress=gr.Progress(),
 ):
     """
     Runs Nari inference using the globally loaded model and provided inputs.
@@ -64,10 +100,20 @@ def run_inference(
 
     if not text_input or text_input.isspace():
         raise gr.Error("Text input cannot be empty.")
+
+    # Show different progress message based on mode
+    if use_streaming:
+        progress(0, desc="Preparing streaming generation...")
+    else:
+        progress(0, desc="Preparing generation...")
 
     temp_txt_file_path = None
     temp_audio_prompt_path = None
     output_audio = (44100, np.zeros(1, dtype=np.float32))
+
+    # For streaming with real-time playback
+    audio_element_html = ""
+    streaming_outputs = []
 
     try:
         prompt_path_for_generate = None
@@ -135,22 +181,106 @@ def run_inference(
                 raise gr.Error(f"Failed to save audio prompt: {write_e}")
 
         # 3. Run Generation
         start_time = time.time()
 
         # Use torch.inference_mode() context manager for the generation call
         with torch.inference_mode():
-            output_audio_np = model.generate(
-                text_input,
-                max_tokens=max_new_tokens,
-                cfg_scale=cfg_scale,
-                temperature=temperature,
-                top_p=top_p,
-                use_cfg_filter=True,
-                cfg_filter_top_k=cfg_filter_top_k,  # Pass the value here
-                use_torch_compile=False,  # Keep False for Gradio stability
-                audio_prompt_path=prompt_path_for_generate,
-            )
+            if use_streaming:
+                progress(0.05, desc="Starting streaming generation...")
+
+                # Collect audio chunks
+                audio_chunks = []
+
+                # Stream generation with progress updates
+                total_chunks_estimate = (max_new_tokens // chunk_size) + 1
+                chunk_count = 0
+
+                # Status message for streaming
+                if autoplay_chunks:
+                    streaming_status = "Audio chunks will play as they're generated. The full audio will be available when generation completes."
+                else:
+                    streaming_status = "Generating audio in streaming mode. The full audio will be available when generation completes."
+ + # Always yield all three outputs - even if empty for now + streaming_outputs = [None, streaming_status, ""] + yield streaming_outputs + + # Use the streaming generator + for i, audio_chunk in enumerate(model.generate_streaming( + text=text_input, + max_tokens=max_new_tokens, + cfg_scale=cfg_scale, + temperature=temperature, + top_p=top_p, + use_cfg_filter=True, + cfg_filter_top_k=cfg_filter_top_k, + use_torch_compile=False, # Keep False for Gradio stability + audio_prompt_path=prompt_path_for_generate, + chunk_size=chunk_size, + overlap=overlap + )): + # Store the chunk + audio_chunks.append(audio_chunk) + chunk_count += 1 + + # Update progress + progress_value = min(0.05 + (0.9 * chunk_count / total_chunks_estimate), 0.95) + progress(progress_value, desc=f"Generated chunk {chunk_count}...") + + # Process the audio segment if autoplay is enabled + if autoplay_chunks: + # Apply speed factor to the chunk + current_len = len(audio_chunk) + target_len = int(current_len / speed_factor) + + if target_len != current_len and target_len > 0: + x_original = np.arange(current_len) + x_resampled = np.linspace(0, current_len - 1, target_len) + resampled_chunk = np.interp(x_resampled, x_original, audio_chunk) + else: + resampled_chunk = audio_chunk + + # Create auto-playing HTML audio element + audio_element_html = create_audio_html(resampled_chunk) + + # Update status message with chunk info + streaming_status = f"Playing chunk {chunk_count} of approximately {total_chunks_estimate} (estimated). Full audio will be available when generation completes." + + # Yield partial results to update the UI + streaming_outputs = [None, streaming_status, audio_element_html] + yield streaming_outputs + else: + # Still need to yield updates even when not auto-playing + streaming_status = f"Generated chunk {chunk_count} of approximately {total_chunks_estimate} (estimated). Full audio will be available when generation completes." + streaming_outputs = [None, streaming_status, ""] + yield streaming_outputs + + # Combine all chunks + output_audio_np = np.concatenate(audio_chunks) + progress(0.95, desc="Processing final audio...") + + # Log completion + print(f"Streaming generation finished with {chunk_count} chunks") + streaming_status = f"Generation complete with {chunk_count} chunks. Total generation time: {time.time() - start_time:.2f} seconds." + # Always include the third parameter (empty HTML string) + streaming_outputs = [None, streaming_status, ""] + yield streaming_outputs + else: + progress(0.1, desc="Generating audio...") + # Use regular generation + output_audio_np = model.generate( + text=text_input, + max_tokens=max_new_tokens, + cfg_scale=cfg_scale, + temperature=temperature, + top_p=top_p, + use_cfg_filter=True, + cfg_filter_top_k=cfg_filter_top_k, + use_torch_compile=False, # Keep False for Gradio stability + audio_prompt_path=prompt_path_for_generate, + ) + progress(0.9, desc="Post-processing audio...") end_time = time.time() print(f"Generation finished in {end_time - start_time:.2f} seconds.") @@ -191,11 +320,25 @@ def run_inference( print( f"Audio conversion successful. 
Final shape: {output_audio[1].shape}, Sample Rate: {output_sr}"
             )
+
+            progress(1.0, desc="Done!")
+
+            # This function is a generator, so both modes must yield all three outputs
+            if use_streaming:
+                streaming_outputs = [output_audio, f"Generation complete in {end_time - start_time:.2f} seconds", ""]
+                yield streaming_outputs
+            else:
+                yield [output_audio, f"Generation complete in {end_time - start_time:.2f} seconds", ""]
         else:
             print("\nGeneration finished, but no valid tokens were produced.")
             # Return default silence
             gr.Warning("Generation produced no output.")
+            if use_streaming:
+                streaming_outputs = [output_audio, "Generation produced no output", ""]
+                yield streaming_outputs
+            else:
+                yield [output_audio, "Generation produced no output", ""]
 
     except Exception as e:
         print(f"Error during inference: {e}")
@@ -203,7 +346,11 @@ def run_inference(
         traceback.print_exc()
 
         # Re-raise as Gradio error to display nicely in the UI
-        raise gr.Error(f"Inference failed: {e}")
+        error_message = f"Inference failed: {e}"
+        if use_streaming:
+            streaming_outputs = [None, error_message, ""]
+            yield streaming_outputs
+        raise gr.Error(error_message)
 
     finally:
         # 5. Cleanup Temporary Files defensively
@@ -224,12 +371,18 @@ def run_inference(
                     f"Warning: Error deleting temporary audio prompt file {temp_audio_prompt_path}: {e}"
                 )
 
-    return output_audio
 
 
 # --- Create Gradio Interface ---
 css = """
 #col-container {max-width: 90%; margin-left: auto; margin-right: auto;}
+.streaming-audio-player {
+    margin-top: 10px;
+    padding: 10px;
+    border-radius: 8px;
+    background-color: #f0f9ff;
+    border: 1px solid #93c5fd;
+}
 """
 # Attempt to load default text from example.txt
 default_text = "[S1] Dia is an open weights text to dialogue model. \n[S2] You get full control over scripts and voices. \n[S1] Wow. Amazing. (laughs) \n[S2] Try it now on Git hub or Hugging Face."
@@ -261,6 +416,24 @@ def run_inference(
                 sources=["upload", "microphone"],
                 type="numpy",
             )
+
+            with gr.Row():
+                with gr.Column(scale=1):
+                    streaming_toggle = gr.Checkbox(
+                        label="Enable Streaming Generation",
+                        value=False,
+                        info="Generate audio in chunks for faster feedback"
+                    )
+                with gr.Column(scale=1):
+                    autoplay_toggle = gr.Checkbox(
+                        label="Auto-play Chunks",
+                        value=True,
+                        info="Play audio chunks as they're generated",
+                        visible=False
+                    )
+
+            run_button = gr.Button("Generate Audio", variant="primary")
+
             with gr.Accordion("Generation Parameters", open=False):
                 max_new_tokens = gr.Slider(
                     label="Max New Tokens (Audio Length)",
@@ -310,16 +483,46 @@ def run_inference(
                     step=0.02,
                     info="Adjusts the speed of the generated audio (1.0 = original speed).",
                 )
-
-            run_button = gr.Button("Generate Audio", variant="primary")
+
+            with gr.Accordion("Streaming Parameters", open=False, visible=False) as streaming_accordion:
+                chunk_size_slider = gr.Slider(
+                    label="Chunk Size",
+                    minimum=50,
+                    maximum=300,
+                    value=100,
+                    step=10,
+                    info="Number of tokens to generate per chunk (smaller = faster first output, but may have more artifacts).",
+                )
+                overlap_slider = gr.Slider(
+                    label="Chunk Overlap",
+                    minimum=0,
+                    maximum=30,
+                    value=10,
+                    step=5,
+                    info="Overlap between consecutive chunks (higher = smoother transitions, but slower).",
+                )
 
         with gr.Column(scale=1):
             audio_output = gr.Audio(
                 label="Generated Audio",
                 type="numpy",
                 autoplay=False,
+                elem_id="main-audio-output"
             )
+            status_output = gr.Markdown("Ready to generate audio")
+            streaming_player = gr.HTML(
+                visible=False,
+                elem_classes="streaming-audio-player",
+                elem_id="streaming-player"
+            )
 
+    # Make streaming parameters and autoplay toggle visible only when streaming is enabled
+    streaming_toggle.change(
+        
fn=lambda x: [gr.update(visible=x), gr.update(visible=x)], + inputs=[streaming_toggle], + outputs=[streaming_accordion, autoplay_toggle] + ) + # Link button click to function run_button.click( fn=run_inference, @@ -332,8 +535,12 @@ def run_inference( top_p, cfg_filter_top_k, speed_factor_slider, + streaming_toggle, + autoplay_toggle, + chunk_size_slider, + overlap_slider, ], - outputs=[audio_output], # Add status_output here if using it + outputs=[audio_output, status_output, streaming_player], api_name="generate_audio", ) @@ -349,6 +556,10 @@ def run_inference( 0.95, 35, 0.94, + False, # Streaming off + True, # Autoplay on + 100, + 10, ], [ "[S1] Open weights text to dialogue model. \n[S2] You get full control over scripts and voices. \n[S1] I'm biased, but I think we clearly won. \n[S2] Hard to disagree. (laughs) \n[S1] Thanks for listening to this demo. \n[S2] Try it now on Git hub and Hugging Face. \n[S1] If you liked our model, please give us a star and share to your friends. \n[S2] This was Nari Labs.", @@ -359,6 +570,10 @@ def run_inference( 0.95, 35, 0.94, + True, # Streaming on + True, # Autoplay on + 100, + 10, ], ] @@ -374,8 +589,12 @@ def run_inference( top_p, cfg_filter_top_k, speed_factor_slider, + streaming_toggle, + autoplay_toggle, + chunk_size_slider, + overlap_slider, ], - outputs=[audio_output], + outputs=[audio_output, status_output, streaming_player], fn=run_inference, cache_examples=False, label="Examples (Click to Run)", diff --git a/dia/model.py b/dia/model.py index bdaeebc2..478bfb3e 100644 --- a/dia/model.py +++ b/dia/model.py @@ -3,8 +3,9 @@ import torch import torchaudio from huggingface_hub import hf_hub_download +from typing import Generator, List, Tuple, Optional, Union -from .audio import audio_to_codebook, codebook_to_audio +from .audio import audio_to_codebook, codebook_to_audio, build_revert_indices, revert_audio_delay, decode from .config import DiaConfig from .layers import DiaModel, KVCache @@ -208,6 +209,60 @@ def _prepare_text_input(self, text: str) -> tuple[torch.Tensor, torch.Tensor, to return src_tokens, src_positions, src_padding_mask, enc_self_attn_mask + def _stream_tokens_to_audio( + self, + tokens: torch.Tensor, + delay_pattern: List[int] + ) -> np.ndarray: + """ + Process a chunk of tokens to generate streaming audio. 
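+        The delayed codebook layout is reverted with build_revert_indices /
+        revert_audio_delay, and the result is decoded to a waveform with the DAC model.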
+ + Args: + tokens: Tensor of shape [1, T, C] containing audio tokens + delay_pattern: List of delays for each channel + + Returns: + numpy array of audio samples + """ + num_channels = tokens.shape[2] + seq_length = tokens.shape[1] + + # Build revert indices for the delay pattern + t_idx_BxTxC, indices_BTCx3 = build_revert_indices( + B=1, + T=seq_length, + C=num_channels, + delay_pattern=delay_pattern + ) + + # Apply the revert operation to get the original tokens + reverted_tokens = revert_audio_delay( + audio_BxTxC=tokens, + pad_value=0, + precomp=(t_idx_BxTxC, indices_BTCx3), + T=seq_length, + ) + + # Transpose to [1, C, T] for the DAC model + codebook = reverted_tokens.transpose(1, 2) + + # Validate token range and clamp if needed + min_valid_index = 0 + max_valid_index = 1023 + invalid_mask = (codebook < min_valid_index) | (codebook > max_valid_index) + + num_invalid = torch.sum(invalid_mask).item() + if num_invalid > 0: + print(f"Warning: Clamping {num_invalid} indices outside range [{min_valid_index}, {max_valid_index}] to 0.") + + # Set invalid values to 0 + codebook[invalid_mask] = 0 + + # Decode the audio tokens to waveform + audio_array = decode(self.dac_model, codebook) + + return audio_array.squeeze().cpu().numpy() + @torch.inference_mode() def generate( self, @@ -438,3 +493,285 @@ def generate( generated_codes.transpose(1, 0), self.dac_model, delay_pattern, B=1, T=max_tokens, C=num_channels ) return audio.squeeze().cpu().numpy() + + @torch.inference_mode() + def generate_streaming( + self, + text: str, + max_tokens: int | None = None, + cfg_scale: float = 3.0, + temperature: float = 1.3, + top_p: float = 0.95, + use_cfg_filter: bool = True, + use_torch_compile: bool = False, + cfg_filter_top_k: int = 35, + audio_prompt_path: str | None = None, + chunk_size: int = 100, # Number of tokens to generate before yielding audio + overlap: int = 10, # Overlap between chunks to prevent boundary artifacts + ) -> Generator[np.ndarray, None, None]: + """ + Generates audio from a text prompt in a streaming fashion, yielding chunks of audio as they're generated. + + Args: + text: The text prompt to generate audio from. + max_tokens: Maximum number of tokens to generate. + cfg_scale: Classifier-free guidance scale. + temperature: Sampling temperature. + top_p: Top-p sampling parameter. + use_cfg_filter: Whether to use classifier-free guidance filtering. + use_torch_compile: Whether to use torch.compile for the decode step. + cfg_filter_top_k: Number of top-k tokens to consider for CFG filtering. + audio_prompt_path: Optional path to an audio prompt. + chunk_size: Number of tokens to generate before yielding audio. + overlap: Overlap between chunks to prevent boundary artifacts. + + Yields: + Chunks of generated audio as numpy arrays. 
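+
+        Example (illustrative only; the prompt text and parameter values here are arbitrary):
+            >>> pieces = []
+            >>> for chunk in model.generate_streaming("[S1] Hello there.", chunk_size=100, overlap=10):
+            ...     pieces.append(chunk)  # each chunk is a 1-D float numpy array at 44.1 kHz
+            >>> full_audio = np.concatenate(pieces)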
+ """ + num_channels = self.config.data.channels + audio_bos_value = self.config.data.audio_bos_value + audio_eos_value = self.config.data.audio_eos_value + audio_pad_value = self.config.data.audio_pad_value + delay_pattern = self.config.data.delay_pattern + max_tokens = self.config.data.audio_length if max_tokens is None else max_tokens + delay_tensor = torch.tensor(delay_pattern, dtype=torch.long, device=self.device) + max_delay_pattern = max(delay_pattern) + self.model.eval() + + # Minimum number of tokens needed before we can start producing audio + # Accounts for the delay pattern to ensure we have enough context + min_tokens_before_streaming = max_delay_pattern + 1 + + # Setup for text input encoding + ( + cond_src_BxS, + cond_src_positions_BxS, + cond_src_padding_mask_BxS, + cond_enc_self_attn_mask_Bx1xSxS, + ) = self._prepare_text_input(text) + + unc_src_BxS = torch.zeros_like(cond_src_BxS) + src_BxS = torch.cat([unc_src_BxS, cond_src_BxS], dim=0) + src_positions_BxS = cond_src_positions_BxS.expand(2, -1) + src_padding_mask_BxS = cond_src_padding_mask_BxS.expand(2, -1) + enc_self_attn_mask_Bx1xSxS = cond_enc_self_attn_mask_Bx1xSxS.expand(2, -1, -1, -1) + + # Encoder Pass + encoder_out = self.model.encoder( + x_ids=src_BxS, + src_positions=src_positions_BxS, + deterministic=True, + attn_mask=enc_self_attn_mask_Bx1xSxS, + ) # Shape: (B, S, E) + + # Prepare Decoder Inputs + # Allocate KV Cache + decoder_cross_attention_cache: list[KVCache] = self.model.decoder.precompute_cross_attention_kv( + max_tokens, encoder_out, src_positions_BxS + ) + + decoder_self_attention_cache: list[KVCache] = [] + for _ in range(self.model.decoder.num_layers): + decoder_self_attention_cache.append( + KVCache( + self.config.model.decoder.gqa_query_heads, + max_tokens, + self.config.model.decoder.gqa_head_dim, + self.device, + ) + ) + + # Initialize decoder inputs with BOS token + generated_BxTxC = torch.full( + (2, 1, num_channels), + fill_value=audio_bos_value, + dtype=torch.long, + device=self.device, + ) + + current_step = 0 + prompt_len_inc_bos = 1 # Start with BOS length + + # Handle audio prompt if provided + if audio_prompt_path is not None: + audio_prompt, sr = torchaudio.load(audio_prompt_path, channels_first=True) # C, T + if sr != 44100: # Resample to 44.1kHz + audio_prompt = torchaudio.functional.resample(audio_prompt, sr, 44100) + audio_prompt = audio_prompt.to(self.device).unsqueeze(0) # 1, C, T + audio_prompt = audio_to_codebook(self.dac_model, audio_prompt, data_config=self.config.data) + generated_BxTxC = torch.cat([generated_BxTxC, audio_prompt.expand(2, -1, -1)], dim=1) + + prefill_len = generated_BxTxC.shape[1] + prompt_len_inc_bos = prefill_len + prefill_tgt_pos = torch.arange(prefill_len, device=self.device).unsqueeze(0).expand(2, -1) + prefill_tgt_padding_mask = (generated_BxTxC != audio_pad_value).any(dim=2) + + prefill_self_attn_mask = self._create_attn_mask( + prefill_tgt_padding_mask, + prefill_tgt_padding_mask, + is_causal=True, + ) + prefill_cross_attn_mask = self._create_attn_mask( + prefill_tgt_padding_mask, + src_padding_mask_BxS, + is_causal=False, + ) + + _ = self.model.decoder.forward( + tgt_ids_BxTxC=generated_BxTxC, + encoder_out=encoder_out, + tgt_positions=prefill_tgt_pos, + src_positions=src_positions_BxS, + deterministic=True, + self_attn_mask=prefill_self_attn_mask, + cross_attn_mask=prefill_cross_attn_mask, + self_attention_cache=decoder_self_attention_cache, + cross_attention_cache=decoder_cross_attention_cache, + ) + + current_step = prefill_len - 1 + + # Setup 
for autoregressive generation + decode_step = self.model.decoder.decode_step + if use_torch_compile: + decode_step = torch.compile( + self.model.decoder.decode_step, + mode="default", + ) + + tgt_padding_mask = ( + (generated_BxTxC[:, -1, :].unsqueeze(1) != audio_pad_value).any(dim=2).to(self.device) + ) # [B, 1] + # Generated tokens are never PAD, so we use fixed mask + decoder_cross_attn_mask = self._create_attn_mask( + tgt_padding_mask, # Query mask [B, 1] + src_padding_mask_BxS, # Key mask [B, S] + is_causal=False, + ) # [B, 1, 1, S] + + # Storage for streaming + streaming_tokens = [] + token_count = 0 + chunk_count = 0 + eos_detected_channel_0 = False + eos_countdown = -1 + extra_steps_after_eos = 30 + + # Autoregressive generation loop + for step in range(current_step, current_step + max_tokens): + # Handle current token position in generation + position_in_output = step - current_step + prompt_len_inc_bos + + # Make sure there's enough room in our generated tensor + if step + 1 >= generated_BxTxC.shape[1]: + # Extend the tensor if needed + generated_BxTxC = torch.cat( + [ + generated_BxTxC, + torch.full( + (2, chunk_size, num_channels), + fill_value=-1, + dtype=torch.long, + device=self.device, + ), + ], + dim=1, + ) + + # Get the current tokens as input to the decoder + tgt_ids_Bx1xC = generated_BxTxC[:, step, :].unsqueeze(1) + tgt_pos_Bx1 = torch.full( + (2, 1), + fill_value=step, + dtype=torch.long, + device=self.device, + ) + + # Generate next token + logits_Bx1xCxV, new_cache = decode_step( + tgt_ids_Bx1xC=tgt_ids_Bx1xC, + tgt_pos_Bx1=tgt_pos_Bx1, + encoder_out=encoder_out, + self_attn_mask=None, + cross_attn_mask=decoder_cross_attn_mask, + self_attention_cache=decoder_self_attention_cache, + cross_attention_cache=decoder_cross_attention_cache, + ) + + # Update KV cache + for i, layer_cache in enumerate(decoder_self_attention_cache): + layer_cache.update_cache(new_cache[i][0], new_cache[i][1]) + + # Sample the next token using classifier-free guidance + V = self.config.model.tgt_vocab_size + logits_last_BxCxV = logits_Bx1xCxV[:, -1, :, :] # B, C, V + uncond_logits_CxV = logits_last_BxCxV[0, :, :] + cond_logits_CxV = logits_last_BxCxV[1, :, :] + + cfg_logits_CxV = cond_logits_CxV + cfg_scale * (cond_logits_CxV - uncond_logits_CxV) + + logits_CxV = cfg_logits_CxV.reshape((-1, V)) # C, V + logits_CxV[:, 1025:] = -torch.inf + + # Sample next token + pred_C = _sample_next_token( + logits_CxV.float(), + temperature=temperature, + top_p=top_p, + use_cfg_filter=use_cfg_filter, + cfg_filter_top_k=cfg_filter_top_k, + ) + + # Handle delay pattern for prompts + generation_step_index = step - current_step + if audio_prompt_path is None: + pred_C = torch.where( + generation_step_index >= delay_tensor, + pred_C, + audio_bos_value, + ) + + # Store the generated token + generated_BxTxC[:, step + 1, :] = pred_C.unsqueeze(0).expand(2, -1) + + # Add to streaming buffer + streaming_tokens.append(pred_C.clone()) + token_count += 1 + + # EOS handling + if not eos_detected_channel_0 and pred_C[0] == audio_eos_value: + eos_detected_channel_0 = True + eos_countdown = extra_steps_after_eos + + if eos_countdown > 0: + step_after_eos = max_delay_pattern - eos_countdown + for i, d in enumerate(delay_pattern): + if step_after_eos == d: + generated_BxTxC[:, step + 1, i] = audio_eos_value + elif step_after_eos > d: + generated_BxTxC[:, step + 1, i] = audio_pad_value + eos_countdown -= 1 + if eos_countdown == 0: + # Process any remaining tokens before exiting + if streaming_tokens: + chunk_tensor = 
torch.stack(streaming_tokens, dim=0).unsqueeze(0) # [1, T, C] + yield self._stream_tokens_to_audio(chunk_tensor, delay_pattern) + break + + # Process a chunk once we have enough tokens + if token_count >= chunk_size and token_count >= min_tokens_before_streaming: + # If we have enough tokens, create a chunk + chunk_tensor = torch.stack(streaming_tokens[:chunk_size-overlap], dim=0).unsqueeze(0) # [1, T, C] + + # Convert to audio and yield + yield self._stream_tokens_to_audio(chunk_tensor, delay_pattern) + + # Keep overlap tokens for next chunk + streaming_tokens = streaming_tokens[chunk_size-overlap:] + token_count = len(streaming_tokens) + chunk_count += 1 + + # Process any remaining tokens + if streaming_tokens and not eos_detected_channel_0: + chunk_tensor = torch.stack(streaming_tokens, dim=0).unsqueeze(0) # [1, T, C] + yield self._stream_tokens_to_audio(chunk_tensor, delay_pattern) diff --git a/test_streaming.py b/test_streaming.py new file mode 100644 index 00000000..2017a418 --- /dev/null +++ b/test_streaming.py @@ -0,0 +1,172 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +""" +Test script for the Dia model with streaming audio generation. +This script demonstrates the difference between regular generation and streaming generation. +""" + +import argparse +import time +import os +import wave +import numpy as np +import pygame +from dia import Dia +from dia.config import DiaConfig + +def write_audio_to_wav(audio_data, filename, sample_rate=44100): + """Write audio data to a WAV file.""" + with wave.open(filename, 'wb') as wav_file: + # Set parameters + wav_file.setnchannels(1) # Mono + wav_file.setsampwidth(2) # 2 bytes (16 bits) per sample + wav_file.setframerate(sample_rate) + + # Convert float32 audio to int16 + audio_int16 = (audio_data * 32767).astype(np.int16) + wav_file.writeframes(audio_int16.tobytes()) + + print(f"Saved audio to {filename}") + +def play_audio(audio_data, sample_rate=44100): + """Play audio using pygame.""" + pygame.mixer.init(frequency=sample_rate, size=-16, channels=1) + # Convert audio to int16 + audio_int16 = (audio_data * 32767).astype(np.int16) + sound = pygame.sndarray.make_sound(audio_int16) + sound.play() + pygame.time.wait(int(1000 * len(audio_data) / sample_rate)) + +def test_regular_generation(model, text, max_tokens=1000): + """Test regular (non-streaming) audio generation.""" + print(f"Generating audio for text: '{text}'") + start_time = time.time() + audio = model.generate( + text=text, + max_tokens=max_tokens, + cfg_scale=3.0, + temperature=1.3, + top_p=0.95, + ) + total_time = time.time() - start_time + + print(f"Generated {len(audio)} audio samples in {total_time:.2f} seconds") + return audio + +def test_streaming_generation(model, text, max_tokens=1000, chunk_size=100, overlap=10, play=False): + """Test streaming audio generation.""" + print(f"Generating streaming audio for text: '{text}'") + print(f"Chunk size: {chunk_size}, Overlap: {overlap}") + + audio_chunks = [] + start_time = time.time() + chunk_times = [] + + # Use the streaming generator + for i, audio_chunk in enumerate(model.generate_streaming( + text=text, + max_tokens=max_tokens, + cfg_scale=3.0, + temperature=1.3, + top_p=0.95, + chunk_size=chunk_size, + overlap=overlap + )): + chunk_time = time.time() - start_time + chunk_times.append(chunk_time) + + # Store the chunk + audio_chunks.append(audio_chunk) + + print(f"Received chunk {i+1}: {len(audio_chunk)} samples after {chunk_time:.2f}s") + + # Optionally play the audio chunk in real-time + if play: + 
play_audio(audio_chunk) + + # Reset start time to measure individual chunk generation time + start_time = time.time() + + # Combine all chunks + # The full audio stream should ideally be handled in real-time in a real application + full_audio = np.concatenate(audio_chunks) + + print(f"Total audio length: {len(full_audio)} samples") + print(f"Generated {len(chunk_times)} chunks with average time {sum(chunk_times)/len(chunk_times):.2f}s per chunk") + + return full_audio, audio_chunks + +def main(): + parser = argparse.ArgumentParser(description="Test Dia model with streaming audio generation") + parser.add_argument("--text", type=str, default="Hello world, this is a test of streaming audio generation.", + help="Text prompt to generate audio from") + parser.add_argument("--model", type=str, default="nari-labs/Dia-1.6B", + help="Model name or path to local model files") + parser.add_argument("--config", type=str, default=None, + help="Path to config.json (only needed if using local model)") + parser.add_argument("--checkpoint", type=str, default=None, + help="Path to checkpoint (only needed if using local model)") + parser.add_argument("--max_tokens", type=int, default=1000, + help="Maximum number of tokens to generate") + parser.add_argument("--chunk_size", type=int, default=100, + help="Number of tokens to generate before yielding audio (for streaming)") + parser.add_argument("--overlap", type=int, default=10, + help="Overlap between chunks to prevent boundary artifacts (for streaming)") + parser.add_argument("--streaming", action="store_true", + help="Use streaming generation mode") + parser.add_argument("--play", action="store_true", + help="Play audio chunks as they're generated (for streaming)") + parser.add_argument("--output", type=str, default="output.wav", + help="Output file for the generated audio") + parser.add_argument("--output_dir", type=str, default="chunks", + help="Directory to save individual audio chunks (for streaming)") + + args = parser.parse_args() + + # Initialize model + print("Loading model...") + if args.config and args.checkpoint: + model = Dia.from_local(args.config, args.checkpoint) + else: + model = Dia.from_pretrained(args.model) + + print("Model loaded successfully!") + + # Test based on mode + if args.streaming: + # Create output directory for chunks if needed + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir) + + # Run streaming test + full_audio, audio_chunks = test_streaming_generation( + model, + args.text, + max_tokens=args.max_tokens, + chunk_size=args.chunk_size, + overlap=args.overlap, + play=args.play + ) + + # Save full audio + write_audio_to_wav(full_audio, args.output) + + # Save individual chunks + for i, chunk in enumerate(audio_chunks): + chunk_file = os.path.join(args.output_dir, f"chunk_{i+1}.wav") + write_audio_to_wav(chunk, chunk_file) + else: + # Run regular test + audio = test_regular_generation(model, args.text, max_tokens=args.max_tokens) + write_audio_to_wav(audio, args.output) + + # Play the audio if requested + if args.play: + print("Playing generated audio...") + play_audio(audio) + + print("Done!") + +if __name__ == "__main__": + main() \ No newline at end of file
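
# Example invocation of the new test script (for reference; these flags are defined by its argparse options, and the output filename is arbitrary):
#   python test_streaming.py --streaming --play --chunk_size 100 --overlap 10 --output streamed.wav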