261 changes: 240 additions & 21 deletions app.py
@@ -1,8 +1,13 @@
import argparse
import tempfile
import time
import queue
import threading
import os
import io
import base64
from pathlib import Path
from typing import Optional, Tuple
from typing import Optional, Tuple, List, Generator, Dict, Any

import gradio as gr
import numpy as np
@@ -46,6 +51,32 @@
raise


def create_audio_html(audio_data: np.ndarray, sample_rate: int = 44100) -> str:
"""
Build an autoplaying HTML <audio> element that embeds the audio as a
base64-encoded WAV data URI.
"""
# Convert audio to 16-bit PCM WAV (clip first to avoid integer wrap-around)
audio_int16 = (np.clip(audio_data, -1.0, 1.0) * 32767).astype(np.int16)

# Create in-memory WAV file
wav_bytes = io.BytesIO()
with sf.SoundFile(wav_bytes, mode='w', samplerate=sample_rate, channels=1, format='WAV', subtype='PCM_16') as f:
f.write(audio_int16)

# Encode as base64
wav_bytes.seek(0)
base64_audio = base64.b64encode(wav_bytes.read()).decode('utf-8')

# Create HTML audio element with autoplay
audio_html = f"""
<audio autoplay controls style="width: 100%;">
<source src="data:audio/wav;base64,{base64_audio}" type="audio/wav">
Your browser does not support the audio element.
</audio>
"""
return audio_html
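
# Illustrative usage sketch (not part of the app flow): wrapping one second of
# a 440 Hz tone for inline playback, assuming the 44.1 kHz default used above.
#   t = np.linspace(0, 1.0, 44100, endpoint=False)
#   demo_html = create_audio_html(0.5 * np.sin(2 * np.pi * 440 * t))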


def run_inference(
text_input: str,
audio_prompt_input: Optional[Tuple[int, np.ndarray]],
@@ -55,6 +86,11 @@ def run_inference(
top_p: float,
cfg_filter_top_k: int,
speed_factor: float,
use_streaming: bool = False,
autoplay_chunks: bool = False,
chunk_size: int = 100,
overlap: int = 10,
progress=gr.Progress(),
):
"""
Runs Nari inference using the globally loaded model and provided inputs.
@@ -64,10 +100,20 @@ def run_inference(

if not text_input or text_input.isspace():
raise gr.Error("Text input cannot be empty.")

# Show different progress message based on mode
if use_streaming:
progress(0, desc="Preparing streaming generation...")
else:
progress(0, desc="Preparing generation...")

temp_txt_file_path = None
temp_audio_prompt_path = None
output_audio = (44100, np.zeros(1, dtype=np.float32))

# For streaming with real-time playback
audio_element_html = ""
streaming_outputs = []

try:
prompt_path_for_generate = None
@@ -135,22 +181,105 @@ def run_inference(
raise gr.Error(f"Failed to save audio prompt: {write_e}")

# 3. Run Generation

start_time = time.time()

# Use torch.inference_mode() context manager for the generation call
# Run the generation call inside the torch.inference_mode() context manager
with torch.inference_mode():
output_audio_np = model.generate(
text_input,
max_tokens=max_new_tokens,
cfg_scale=cfg_scale,
temperature=temperature,
top_p=top_p,
use_cfg_filter=True,
cfg_filter_top_k=cfg_filter_top_k, # Pass the value here
use_torch_compile=False, # Keep False for Gradio stability
audio_prompt_path=prompt_path_for_generate,
)
if use_streaming:
progress(0.05, desc="Starting streaming generation...")

# Collect audio chunks
audio_chunks = []

# Stream generation with progress updates
total_chunks_estimate = (max_new_tokens // chunk_size) + 1
chunk_count = 0

# Status message for streaming
if autoplay_chunks:
streaming_status = "Audio chunks will play as they're generated. The full audio will be available when generation completes."
else:
streaming_status = "Generating audio in streaming mode. The full audio will be available when generation completes."

# Always yield all three outputs - even if empty for now
streaming_outputs = [None, streaming_status, ""]
yield streaming_outputs

# Use the streaming generator
for audio_chunk in model.generate_streaming(
text=text_input,
max_tokens=max_new_tokens,
cfg_scale=cfg_scale,
temperature=temperature,
top_p=top_p,
use_cfg_filter=True,
cfg_filter_top_k=cfg_filter_top_k,
use_torch_compile=False, # Keep False for Gradio stability
audio_prompt_path=prompt_path_for_generate,
chunk_size=chunk_size,
overlap=overlap
):
# Store the chunk
audio_chunks.append(audio_chunk)
chunk_count += 1

# Update progress
progress_value = min(0.05 + (0.9 * chunk_count / total_chunks_estimate), 0.95)
progress(progress_value, desc=f"Generated chunk {chunk_count}...")

# Process the audio segment if autoplay is enabled
if autoplay_chunks:
# Apply speed factor to the chunk
current_len = len(audio_chunk)
target_len = int(current_len / speed_factor)

if target_len != current_len and target_len > 0:
x_original = np.arange(current_len)
x_resampled = np.linspace(0, current_len - 1, target_len)
resampled_chunk = np.interp(x_resampled, x_original, audio_chunk)
else:
resampled_chunk = audio_chunk

# Create auto-playing HTML audio element
audio_element_html = create_audio_html(resampled_chunk)

# Update status message with chunk info
streaming_status = f"Playing chunk {chunk_count} of approximately {total_chunks_estimate} (estimated). Full audio will be available when generation completes."

# Yield partial results to update the UI
streaming_outputs = [None, streaming_status, audio_element_html]
yield streaming_outputs
else:
# Still need to yield updates even when not auto-playing
streaming_status = f"Generated chunk {chunk_count} of approximately {total_chunks_estimate} (estimated). Full audio will be available when generation completes."
streaming_outputs = [None, streaming_status, ""]
yield streaming_outputs

# Combine all chunks (None if nothing was produced; handled below)
output_audio_np = np.concatenate(audio_chunks) if audio_chunks else None
progress(0.95, desc="Processing final audio...")

# Log completion
print(f"Streaming generation finished with {chunk_count} chunks")
streaming_status = f"Generation complete with {chunk_count} chunks. Total generation time: {time.time() - start_time:.2f} seconds."
# Always include the third parameter (empty HTML string)
streaming_outputs = [None, streaming_status, ""]
yield streaming_outputs
else:
progress(0.1, desc="Generating audio...")
# Use regular generation
output_audio_np = model.generate(
text=text_input,
max_tokens=max_new_tokens,
cfg_scale=cfg_scale,
temperature=temperature,
top_p=top_p,
use_cfg_filter=True,
cfg_filter_top_k=cfg_filter_top_k,
use_torch_compile=False, # Keep False for Gradio stability
audio_prompt_path=prompt_path_for_generate,
)
progress(0.9, desc="Post-processing audio...")

end_time = time.time()
print(f"Generation finished in {end_time - start_time:.2f} seconds.")
@@ -191,19 +320,37 @@ def run_inference(
print(
f"Audio conversion successful. Final shape: {output_audio[1].shape}, Sample Rate: {output_sr}"
)

progress(1.0, desc="Done!")

# Yield the final result; both modes supply all three outputs, because a
# plain return from a generator function never reaches the Gradio outputs
final_status = f"Generation complete in {end_time - start_time:.2f} seconds"
yield [output_audio, final_status, ""]

else:
print("\nGeneration finished, but no valid tokens were produced.")
# Yield default silence plus a status message (both modes)
gr.Warning("Generation produced no output.")
yield [output_audio, "Generation produced no output", ""]

except Exception as e:
print(f"Error during inference: {e}")
import traceback

traceback.print_exc()
# Re-raise as Gradio error to display nicely in the UI
raise gr.Error(f"Inference failed: {e}")
error_message = f"Inference failed: {e}"
if use_streaming:
streaming_outputs = [None, error_message, ""]
yield streaming_outputs
raise gr.Error(error_message)

finally:
# 5. Cleanup Temporary Files defensively
@@ -224,12 +371,20 @@ def run_inference(
f"Warning: Error deleting temporary audio prompt file {temp_audio_prompt_path}: {e}"
)

return output_audio
# No return here: results are yielded above, and returning from the finally
# block would swallow any gr.Error raised in the except block.


# --- Create Gradio Interface ---
css = """
#col-container {max-width: 90%; margin-left: auto; margin-right: auto;}
.streaming-audio-player {
margin-top: 10px;
padding: 10px;
border-radius: 8px;
background-color: #f0f9ff;
border: 1px solid #93c5fd;
}
"""
# Attempt to load default text from example.txt
default_text = "[S1] Dia is an open weights text to dialogue model. \n[S2] You get full control over scripts and voices. \n[S1] Wow. Amazing. (laughs) \n[S2] Try it now on Git hub or Hugging Face."
@@ -261,6 +416,24 @@ def run_inference(
sources=["upload", "microphone"],
type="numpy",
)

with gr.Row():
with gr.Column(scale=1):
streaming_toggle = gr.Checkbox(
label="Enable Streaming Generation",
value=False,
info="Generate audio in chunks for faster feedback"
)
with gr.Column(scale=1):
autoplay_toggle = gr.Checkbox(
label="Auto-play Chunks",
value=True,
info="Play audio chunks as they're generated",
visible=False
)

run_button = gr.Button("Generate Audio", variant="primary")

with gr.Accordion("Generation Parameters", open=False):
max_new_tokens = gr.Slider(
label="Max New Tokens (Audio Length)",
@@ -310,16 +483,46 @@ def run_inference(
step=0.02,
info="Adjusts the speed of the generated audio (1.0 = original speed).",
)

run_button = gr.Button("Generate Audio", variant="primary")

with gr.Accordion("Streaming Parameters", open=False, visible=False) as streaming_accordion:
chunk_size_slider = gr.Slider(
label="Chunk Size",
minimum=50,
maximum=300,
value=100,
step=10,
info="Number of tokens to generate per chunk (smaller = faster first output, but may have more artifacts).",
)
overlap_slider = gr.Slider(
label="Chunk Overlap",
minimum=0,
maximum=30,
value=10,
step=5,
info="Overlap between consecutive chunks (higher = smoother transitions, but slower).",
)

with gr.Column(scale=1):
audio_output = gr.Audio(
label="Generated Audio",
type="numpy",
autoplay=False,
elem_id="main-audio-output"
)
status_output = gr.Markdown("Ready to generate audio")
streaming_player = gr.HTML(
visible=False,
elem_classes="streaming-audio-player",
elem_id="streaming-player"
)

# Make streaming parameters, the autoplay toggle, and the inline chunk player
# visible only when streaming is enabled
streaming_toggle.change(
fn=lambda x: [gr.update(visible=x), gr.update(visible=x), gr.update(visible=x)],
inputs=[streaming_toggle],
outputs=[streaming_accordion, autoplay_toggle, streaming_player]
)

# Link button click to function
run_button.click(
fn=run_inference,
@@ -332,8 +535,12 @@ def run_inference(
top_p,
cfg_filter_top_k,
speed_factor_slider,
streaming_toggle,
autoplay_toggle,
chunk_size_slider,
overlap_slider,
],
outputs=[audio_output], # Add status_output here if using it
outputs=[audio_output, status_output, streaming_player],
api_name="generate_audio",
)
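
# Note: because run_inference contains yield statements it is a generator
# function, so Gradio streams each yielded [audio, status, html] triple to the
# three outputs above as generation progresses.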

@@ -349,6 +556,10 @@ def run_inference(
0.95,
35,
0.94,
False, # Streaming off
True, # Autoplay on
100,
10,
],
[
"[S1] Open weights text to dialogue model. \n[S2] You get full control over scripts and voices. \n[S1] I'm biased, but I think we clearly won. \n[S2] Hard to disagree. (laughs) \n[S1] Thanks for listening to this demo. \n[S2] Try it now on Git hub and Hugging Face. \n[S1] If you liked our model, please give us a star and share to your friends. \n[S2] This was Nari Labs.",
@@ -359,6 +570,10 @@ def run_inference(
0.95,
35,
0.94,
True, # Streaming on
True, # Autoplay on
100,
10,
],
]

@@ -374,8 +589,12 @@ def run_inference(
top_p,
cfg_filter_top_k,
speed_factor_slider,
streaming_toggle,
autoplay_toggle,
chunk_size_slider,
overlap_slider,
],
outputs=[audio_output],
outputs=[audio_output, status_output, streaming_player],
fn=run_inference,
cache_examples=False,
label="Examples (Click to Run)",