Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 32 additions & 123 deletions acestep/ui/gradio/events/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,27 @@
from . import generation_handlers as gen_h
from . import results_handlers as res_h
from . import training_handlers as train_h
from .wiring import (
GenerationWiringContext,
TrainingWiringContext,
build_auto_checkbox_inputs,
build_auto_checkbox_outputs,
build_mode_ui_outputs,
)
from acestep.ui.gradio.i18n import t


def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, dataset_section, generation_section, results_section):
"""Setup event handlers connecting UI components and business logic"""
wiring_context = GenerationWiringContext(
demo=demo,
dit_handler=dit_handler,
llm_handler=llm_handler,
dataset_handler=dataset_handler,
dataset_section=dataset_section,
generation_section=generation_section,
results_section=results_section,
)

# ========== Dataset Handlers ==========
dataset_section["import_dataset_btn"].click(
Expand Down Expand Up @@ -190,25 +206,9 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
)

# Auto-checkbox outputs used by .then() chains after events that populate fields
_auto_checkbox_outputs = [
generation_section["bpm_auto"],
generation_section["key_auto"],
generation_section["timesig_auto"],
generation_section["vocal_lang_auto"],
generation_section["duration_auto"],
generation_section["bpm"],
generation_section["key_scale"],
generation_section["time_signature"],
generation_section["vocal_language"],
generation_section["audio_duration"],
]
_auto_checkbox_inputs = [
generation_section["bpm"],
generation_section["key_scale"],
generation_section["time_signature"],
generation_section["vocal_language"],
generation_section["audio_duration"],
]
_auto_checkbox_outputs = build_auto_checkbox_outputs(wiring_context)
_auto_checkbox_inputs = build_auto_checkbox_inputs(wiring_context)
_mode_ui_outputs = build_mode_ui_outputs(wiring_context)

# ========== Audio Conversion (LM Codes Hints accordion in Custom mode) ==========
generation_section["convert_src_to_codes_btn"].click(
Expand Down Expand Up @@ -418,58 +418,7 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
generation_section["generation_mode"],
generation_section["previous_generation_mode"],
],
outputs=[
generation_section["simple_mode_group"],
generation_section["custom_mode_group"],
generation_section["generate_btn"],
generation_section["simple_sample_created"],
generation_section["optional_params_accordion"],
generation_section["task_type"],
generation_section["src_audio_row"],
generation_section["repainting_group"],
generation_section["text2music_audio_codes_group"],
generation_section["track_name"],
generation_section["complete_track_classes"],
generation_section["generate_btn_row"],
generation_section["generation_mode"],
generation_section["results_wrapper"],
generation_section["think_checkbox"],
generation_section["load_file_col"],
generation_section["load_file"],
generation_section["audio_cover_strength"],
generation_section["cover_noise_strength"],
# Extract/Lego-mode outputs (indices 19-29)
generation_section["captions"],
generation_section["lyrics"],
generation_section["bpm"],
generation_section["key_scale"],
generation_section["time_signature"],
generation_section["vocal_language"],
generation_section["audio_duration"],
generation_section["auto_score"],
generation_section["autogen_checkbox"],
generation_section["auto_lrc"],
generation_section["analyze_btn"],
# Dynamic repainting/stem labels (indices 30-32)
generation_section["repainting_header_html"],
generation_section["repainting_start"],
generation_section["repainting_end"],
# Previous mode state (index 33)
generation_section["previous_generation_mode"],
# Mode-specific help button groups (indices 34-36)
generation_section["remix_help_group"],
generation_section["extract_help_group"],
generation_section["complete_help_group"],
# Auto checkbox updates (indices 37-41)
generation_section["bpm_auto"],
generation_section["key_auto"],
generation_section["timesig_auto"],
generation_section["vocal_lang_auto"],
generation_section["duration_auto"],
# State-leakage fix: clear stale values on mode switch (indices 42-43)
generation_section["text2music_audio_code_string"],
generation_section["src_audio"],
]
outputs=_mode_ui_outputs
)

# ========== Extract Mode: Auto-fill caption from track_name ==========
Expand Down Expand Up @@ -687,58 +636,6 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
# ========== Send to Remix / Repaint Handlers ==========
# Mode-UI outputs shared with generation_mode.change — applied atomically
# so we don't rely on a chained .change() event for visibility/label updates.
_mode_ui_outputs = [
generation_section["simple_mode_group"],
generation_section["custom_mode_group"],
generation_section["generate_btn"],
generation_section["simple_sample_created"],
generation_section["optional_params_accordion"],
generation_section["task_type"],
generation_section["src_audio_row"],
generation_section["repainting_group"],
generation_section["text2music_audio_codes_group"],
generation_section["track_name"],
generation_section["complete_track_classes"],
generation_section["generate_btn_row"],
generation_section["generation_mode"],
generation_section["results_wrapper"],
generation_section["think_checkbox"],
generation_section["load_file_col"],
generation_section["load_file"],
generation_section["audio_cover_strength"],
generation_section["cover_noise_strength"],
# Extract/Lego-mode outputs (indices 19-29)
generation_section["captions"],
generation_section["lyrics"],
generation_section["bpm"],
generation_section["key_scale"],
generation_section["time_signature"],
generation_section["vocal_language"],
generation_section["audio_duration"],
generation_section["auto_score"],
generation_section["autogen_checkbox"],
generation_section["auto_lrc"],
generation_section["analyze_btn"],
# Dynamic repainting/stem labels (indices 30-32)
generation_section["repainting_header_html"],
generation_section["repainting_start"],
generation_section["repainting_end"],
# Previous mode state (index 33)
generation_section["previous_generation_mode"],
# Mode-specific help button groups (indices 34-36)
generation_section["remix_help_group"],
generation_section["extract_help_group"],
generation_section["complete_help_group"],
# Auto checkbox updates (indices 37-41)
generation_section["bpm_auto"],
generation_section["key_auto"],
generation_section["timesig_auto"],
generation_section["vocal_lang_auto"],
generation_section["duration_auto"],
# State-leakage fix: clear stale values on mode switch (indices 42-43)
generation_section["text2music_audio_code_string"],
generation_section["src_audio"],
]
for btn_idx in range(1, 9):
results_section[f"send_to_remix_btn_{btn_idx}"].click(
fn=lambda audio, lm, ly, cap, cur_mode: res_h.send_audio_to_remix(
Expand Down Expand Up @@ -778,6 +675,7 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
# ========== Score Calculation Handlers ==========
# Use default argument to capture btn_idx value at definition time (Python closure fix)
def make_score_handler(idx):
"""Build a score-click callback bound to a fixed result slot index."""
return lambda scale, batch_idx, queue: res_h.calculate_score_handler_with_selection(
dit_handler, llm_handler, idx, scale, batch_idx, queue
)
Expand All @@ -800,6 +698,7 @@ def make_score_handler(idx):
# ========== LRC Timestamp Handlers ==========
# Use default argument to capture btn_idx value at definition time (Python closure fix)
def make_lrc_handler(idx):
"""Build an LRC-generation callback bound to a fixed result slot index."""
return lambda batch_idx, queue, vocal_lang, infer_steps: res_h.generate_lrc_handler(
dit_handler, idx, batch_idx, queue, vocal_lang, infer_steps
)
Expand Down Expand Up @@ -842,6 +741,7 @@ def make_lrc_handler(idx):
)

def generation_wrapper(*args):
"""Proxy batched generation to the results handler stream."""
yield from res_h.generate_with_batch_management(dit_handler, llm_handler, *args)
# ========== Generation Handler ==========
generation_section["generate_btn"].click(
Expand Down Expand Up @@ -1239,6 +1139,13 @@ def generation_wrapper(*args):

def setup_training_event_handlers(demo, dit_handler, llm_handler, training_section):
"""Setup event handlers for the training tab (dataset builder and LoRA training)"""
training_context = TrainingWiringContext(
demo=demo,
dit_handler=dit_handler,
llm_handler=llm_handler,
training_section=training_section,
)
training_section = training_context.training_section

# ========== Load Existing Dataset (Top Section) ==========

Expand Down Expand Up @@ -1516,6 +1423,7 @@ def setup_training_event_handlers(demo, dit_handler, llm_handler, training_secti

# Start training from preprocessed tensors
def training_wrapper(tensor_dir, r, a, d, lr, ep, bs, ga, se, sh, sd, od, rc, ts):
"""Stream LoRA training progress and normalize failure outputs for the UI."""
from loguru import logger
if not isinstance(ts, dict):
ts = {"is_training": False, "should_stop": False}
Expand Down Expand Up @@ -1588,6 +1496,7 @@ def lokr_training_wrapper(
tensor_dir, ldim, lalpha, factor, decompose_both, use_tucker,
use_scalar, weight_decompose, lr, ep, bs, ga, se, sh, sd, od, ts,
):
"""Stream LoKr training progress and normalize failure outputs for the UI."""
from loguru import logger
if not isinstance(ts, dict):
ts = {"is_training": False, "should_stop": False}
Expand Down
21 changes: 21 additions & 0 deletions acestep/ui/gradio/events/wiring/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
"""Wiring helpers for Gradio event registration.

This package provides shared context and list-builder helpers used by the
event wiring facade in ``acestep.ui.gradio.events``.
"""

from .context import (
GenerationWiringContext,
TrainingWiringContext,
build_auto_checkbox_inputs,
build_auto_checkbox_outputs,
build_mode_ui_outputs,
)

__all__ = [
"GenerationWiringContext",
"TrainingWiringContext",
"build_auto_checkbox_inputs",
"build_auto_checkbox_outputs",
"build_mode_ui_outputs",
]
124 changes: 124 additions & 0 deletions acestep/ui/gradio/events/wiring/context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
"""Typed wiring context and shared component-list builders.

The event facade contains several long output/input lists that must remain in
strict order. This module centralizes those list contracts to reduce copy-paste
risk while keeping handler wiring behavior unchanged.
"""

from dataclasses import dataclass
from typing import Any, Mapping


ComponentMap = Mapping[str, Any]

_AUTO_CHECKBOX_OUTPUT_KEYS = (
"bpm_auto",
"key_auto",
"timesig_auto",
"vocal_lang_auto",
"duration_auto",
"bpm",
"key_scale",
"time_signature",
"vocal_language",
"audio_duration",
)

_AUTO_CHECKBOX_INPUT_KEYS = (
"bpm",
"key_scale",
"time_signature",
"vocal_language",
"audio_duration",
)

_MODE_UI_OUTPUT_KEYS = (
"simple_mode_group",
"custom_mode_group",
"generate_btn",
"simple_sample_created",
"optional_params_accordion",
"task_type",
"src_audio_row",
"repainting_group",
"text2music_audio_codes_group",
"track_name",
"complete_track_classes",
"generate_btn_row",
"generation_mode",
"results_wrapper",
"think_checkbox",
"load_file_col",
"load_file",
"audio_cover_strength",
"cover_noise_strength",
"captions",
"lyrics",
"bpm",
"key_scale",
"time_signature",
"vocal_language",
"audio_duration",
"auto_score",
"autogen_checkbox",
"auto_lrc",
"analyze_btn",
"repainting_header_html",
"repainting_start",
"repainting_end",
"previous_generation_mode",
"remix_help_group",
"extract_help_group",
"complete_help_group",
"bpm_auto",
"key_auto",
"timesig_auto",
"vocal_lang_auto",
"duration_auto",
"text2music_audio_code_string",
"src_audio",
)


@dataclass(frozen=True)
class GenerationWiringContext:
"""Inputs required for generation/results event wiring."""

demo: Any
dit_handler: Any
llm_handler: Any
dataset_handler: Any
dataset_section: ComponentMap
generation_section: ComponentMap
results_section: ComponentMap


@dataclass(frozen=True)
class TrainingWiringContext:
"""Inputs required for training event wiring."""

demo: Any
dit_handler: Any
llm_handler: Any
training_section: ComponentMap


def build_auto_checkbox_outputs(context: GenerationWiringContext) -> list[Any]:
"""Return ordered auto-checkbox outputs for metadata field sync."""

generation = context.generation_section
return [generation[key] for key in _AUTO_CHECKBOX_OUTPUT_KEYS]


def build_auto_checkbox_inputs(context: GenerationWiringContext) -> list[Any]:
"""Return ordered metadata fields used to derive auto-checkbox state."""

generation = context.generation_section
return [generation[key] for key in _AUTO_CHECKBOX_INPUT_KEYS]


def build_mode_ui_outputs(context: GenerationWiringContext) -> list[Any]:
"""Return ordered mode-UI outputs shared across mode/remix/repaint wiring."""

generation = context.generation_section
return [generation[key] for key in _MODE_UI_OUTPUT_KEYS]
Loading