diff --git a/acestep/ui/gradio/events/__init__.py b/acestep/ui/gradio/events/__init__.py index 177b4b58..9bc9da42 100644 --- a/acestep/ui/gradio/events/__init__.py +++ b/acestep/ui/gradio/events/__init__.py @@ -10,11 +10,27 @@ from . import generation_handlers as gen_h from . import results_handlers as res_h from . import training_handlers as train_h +from .wiring import ( + GenerationWiringContext, + TrainingWiringContext, + build_auto_checkbox_inputs, + build_auto_checkbox_outputs, + build_mode_ui_outputs, +) from acestep.ui.gradio.i18n import t def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, dataset_section, generation_section, results_section): """Setup event handlers connecting UI components and business logic""" + wiring_context = GenerationWiringContext( + demo=demo, + dit_handler=dit_handler, + llm_handler=llm_handler, + dataset_handler=dataset_handler, + dataset_section=dataset_section, + generation_section=generation_section, + results_section=results_section, + ) # ========== Dataset Handlers ========== dataset_section["import_dataset_btn"].click( @@ -190,25 +206,9 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase ) # Auto-checkbox outputs used by .then() chains after events that populate fields - _auto_checkbox_outputs = [ - generation_section["bpm_auto"], - generation_section["key_auto"], - generation_section["timesig_auto"], - generation_section["vocal_lang_auto"], - generation_section["duration_auto"], - generation_section["bpm"], - generation_section["key_scale"], - generation_section["time_signature"], - generation_section["vocal_language"], - generation_section["audio_duration"], - ] - _auto_checkbox_inputs = [ - generation_section["bpm"], - generation_section["key_scale"], - generation_section["time_signature"], - generation_section["vocal_language"], - generation_section["audio_duration"], - ] + _auto_checkbox_outputs = build_auto_checkbox_outputs(wiring_context) + _auto_checkbox_inputs = build_auto_checkbox_inputs(wiring_context) + _mode_ui_outputs = build_mode_ui_outputs(wiring_context) # ========== Audio Conversion (LM Codes Hints accordion in Custom mode) ========== generation_section["convert_src_to_codes_btn"].click( @@ -418,58 +418,7 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase generation_section["generation_mode"], generation_section["previous_generation_mode"], ], - outputs=[ - generation_section["simple_mode_group"], - generation_section["custom_mode_group"], - generation_section["generate_btn"], - generation_section["simple_sample_created"], - generation_section["optional_params_accordion"], - generation_section["task_type"], - generation_section["src_audio_row"], - generation_section["repainting_group"], - generation_section["text2music_audio_codes_group"], - generation_section["track_name"], - generation_section["complete_track_classes"], - generation_section["generate_btn_row"], - generation_section["generation_mode"], - generation_section["results_wrapper"], - generation_section["think_checkbox"], - generation_section["load_file_col"], - generation_section["load_file"], - generation_section["audio_cover_strength"], - generation_section["cover_noise_strength"], - # Extract/Lego-mode outputs (indices 19-29) - generation_section["captions"], - generation_section["lyrics"], - generation_section["bpm"], - generation_section["key_scale"], - generation_section["time_signature"], - generation_section["vocal_language"], - generation_section["audio_duration"], - generation_section["auto_score"], - generation_section["autogen_checkbox"], - generation_section["auto_lrc"], - generation_section["analyze_btn"], - # Dynamic repainting/stem labels (indices 30-32) - generation_section["repainting_header_html"], - generation_section["repainting_start"], - generation_section["repainting_end"], - # Previous mode state (index 33) - generation_section["previous_generation_mode"], - # Mode-specific help button groups (indices 34-36) - generation_section["remix_help_group"], - generation_section["extract_help_group"], - generation_section["complete_help_group"], - # Auto checkbox updates (indices 37-41) - generation_section["bpm_auto"], - generation_section["key_auto"], - generation_section["timesig_auto"], - generation_section["vocal_lang_auto"], - generation_section["duration_auto"], - # State-leakage fix: clear stale values on mode switch (indices 42-43) - generation_section["text2music_audio_code_string"], - generation_section["src_audio"], - ] + outputs=_mode_ui_outputs ) # ========== Extract Mode: Auto-fill caption from track_name ========== @@ -687,58 +636,6 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase # ========== Send to Remix / Repaint Handlers ========== # Mode-UI outputs shared with generation_mode.change — applied atomically # so we don't rely on a chained .change() event for visibility/label updates. - _mode_ui_outputs = [ - generation_section["simple_mode_group"], - generation_section["custom_mode_group"], - generation_section["generate_btn"], - generation_section["simple_sample_created"], - generation_section["optional_params_accordion"], - generation_section["task_type"], - generation_section["src_audio_row"], - generation_section["repainting_group"], - generation_section["text2music_audio_codes_group"], - generation_section["track_name"], - generation_section["complete_track_classes"], - generation_section["generate_btn_row"], - generation_section["generation_mode"], - generation_section["results_wrapper"], - generation_section["think_checkbox"], - generation_section["load_file_col"], - generation_section["load_file"], - generation_section["audio_cover_strength"], - generation_section["cover_noise_strength"], - # Extract/Lego-mode outputs (indices 19-29) - generation_section["captions"], - generation_section["lyrics"], - generation_section["bpm"], - generation_section["key_scale"], - generation_section["time_signature"], - generation_section["vocal_language"], - generation_section["audio_duration"], - generation_section["auto_score"], - generation_section["autogen_checkbox"], - generation_section["auto_lrc"], - generation_section["analyze_btn"], - # Dynamic repainting/stem labels (indices 30-32) - generation_section["repainting_header_html"], - generation_section["repainting_start"], - generation_section["repainting_end"], - # Previous mode state (index 33) - generation_section["previous_generation_mode"], - # Mode-specific help button groups (indices 34-36) - generation_section["remix_help_group"], - generation_section["extract_help_group"], - generation_section["complete_help_group"], - # Auto checkbox updates (indices 37-41) - generation_section["bpm_auto"], - generation_section["key_auto"], - generation_section["timesig_auto"], - generation_section["vocal_lang_auto"], - generation_section["duration_auto"], - # State-leakage fix: clear stale values on mode switch (indices 42-43) - generation_section["text2music_audio_code_string"], - generation_section["src_audio"], - ] for btn_idx in range(1, 9): results_section[f"send_to_remix_btn_{btn_idx}"].click( fn=lambda audio, lm, ly, cap, cur_mode: res_h.send_audio_to_remix( @@ -778,6 +675,7 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase # ========== Score Calculation Handlers ========== # Use default argument to capture btn_idx value at definition time (Python closure fix) def make_score_handler(idx): + """Build a score-click callback bound to a fixed result slot index.""" return lambda scale, batch_idx, queue: res_h.calculate_score_handler_with_selection( dit_handler, llm_handler, idx, scale, batch_idx, queue ) @@ -800,6 +698,7 @@ def make_score_handler(idx): # ========== LRC Timestamp Handlers ========== # Use default argument to capture btn_idx value at definition time (Python closure fix) def make_lrc_handler(idx): + """Build an LRC-generation callback bound to a fixed result slot index.""" return lambda batch_idx, queue, vocal_lang, infer_steps: res_h.generate_lrc_handler( dit_handler, idx, batch_idx, queue, vocal_lang, infer_steps ) @@ -842,6 +741,7 @@ def make_lrc_handler(idx): ) def generation_wrapper(*args): + """Proxy batched generation to the results handler stream.""" yield from res_h.generate_with_batch_management(dit_handler, llm_handler, *args) # ========== Generation Handler ========== generation_section["generate_btn"].click( @@ -1239,6 +1139,13 @@ def generation_wrapper(*args): def setup_training_event_handlers(demo, dit_handler, llm_handler, training_section): """Setup event handlers for the training tab (dataset builder and LoRA training)""" + training_context = TrainingWiringContext( + demo=demo, + dit_handler=dit_handler, + llm_handler=llm_handler, + training_section=training_section, + ) + training_section = training_context.training_section # ========== Load Existing Dataset (Top Section) ========== @@ -1516,6 +1423,7 @@ def setup_training_event_handlers(demo, dit_handler, llm_handler, training_secti # Start training from preprocessed tensors def training_wrapper(tensor_dir, r, a, d, lr, ep, bs, ga, se, sh, sd, od, rc, ts): + """Stream LoRA training progress and normalize failure outputs for the UI.""" from loguru import logger if not isinstance(ts, dict): ts = {"is_training": False, "should_stop": False} @@ -1588,6 +1496,7 @@ def lokr_training_wrapper( tensor_dir, ldim, lalpha, factor, decompose_both, use_tucker, use_scalar, weight_decompose, lr, ep, bs, ga, se, sh, sd, od, ts, ): + """Stream LoKr training progress and normalize failure outputs for the UI.""" from loguru import logger if not isinstance(ts, dict): ts = {"is_training": False, "should_stop": False} diff --git a/acestep/ui/gradio/events/wiring/__init__.py b/acestep/ui/gradio/events/wiring/__init__.py new file mode 100644 index 00000000..af28ef6a --- /dev/null +++ b/acestep/ui/gradio/events/wiring/__init__.py @@ -0,0 +1,21 @@ +"""Wiring helpers for Gradio event registration. + +This package provides shared context and list-builder helpers used by the +event wiring facade in ``acestep.ui.gradio.events``. +""" + +from .context import ( + GenerationWiringContext, + TrainingWiringContext, + build_auto_checkbox_inputs, + build_auto_checkbox_outputs, + build_mode_ui_outputs, +) + +__all__ = [ + "GenerationWiringContext", + "TrainingWiringContext", + "build_auto_checkbox_inputs", + "build_auto_checkbox_outputs", + "build_mode_ui_outputs", +] diff --git a/acestep/ui/gradio/events/wiring/context.py b/acestep/ui/gradio/events/wiring/context.py new file mode 100644 index 00000000..b6c7b229 --- /dev/null +++ b/acestep/ui/gradio/events/wiring/context.py @@ -0,0 +1,124 @@ +"""Typed wiring context and shared component-list builders. + +The event facade contains several long output/input lists that must remain in +strict order. This module centralizes those list contracts to reduce copy-paste +risk while keeping handler wiring behavior unchanged. +""" + +from dataclasses import dataclass +from typing import Any, Mapping + + +ComponentMap = Mapping[str, Any] + +_AUTO_CHECKBOX_OUTPUT_KEYS = ( + "bpm_auto", + "key_auto", + "timesig_auto", + "vocal_lang_auto", + "duration_auto", + "bpm", + "key_scale", + "time_signature", + "vocal_language", + "audio_duration", +) + +_AUTO_CHECKBOX_INPUT_KEYS = ( + "bpm", + "key_scale", + "time_signature", + "vocal_language", + "audio_duration", +) + +_MODE_UI_OUTPUT_KEYS = ( + "simple_mode_group", + "custom_mode_group", + "generate_btn", + "simple_sample_created", + "optional_params_accordion", + "task_type", + "src_audio_row", + "repainting_group", + "text2music_audio_codes_group", + "track_name", + "complete_track_classes", + "generate_btn_row", + "generation_mode", + "results_wrapper", + "think_checkbox", + "load_file_col", + "load_file", + "audio_cover_strength", + "cover_noise_strength", + "captions", + "lyrics", + "bpm", + "key_scale", + "time_signature", + "vocal_language", + "audio_duration", + "auto_score", + "autogen_checkbox", + "auto_lrc", + "analyze_btn", + "repainting_header_html", + "repainting_start", + "repainting_end", + "previous_generation_mode", + "remix_help_group", + "extract_help_group", + "complete_help_group", + "bpm_auto", + "key_auto", + "timesig_auto", + "vocal_lang_auto", + "duration_auto", + "text2music_audio_code_string", + "src_audio", +) + + +@dataclass(frozen=True) +class GenerationWiringContext: + """Inputs required for generation/results event wiring.""" + + demo: Any + dit_handler: Any + llm_handler: Any + dataset_handler: Any + dataset_section: ComponentMap + generation_section: ComponentMap + results_section: ComponentMap + + +@dataclass(frozen=True) +class TrainingWiringContext: + """Inputs required for training event wiring.""" + + demo: Any + dit_handler: Any + llm_handler: Any + training_section: ComponentMap + + +def build_auto_checkbox_outputs(context: GenerationWiringContext) -> list[Any]: + """Return ordered auto-checkbox outputs for metadata field sync.""" + + generation = context.generation_section + return [generation[key] for key in _AUTO_CHECKBOX_OUTPUT_KEYS] + + +def build_auto_checkbox_inputs(context: GenerationWiringContext) -> list[Any]: + """Return ordered metadata fields used to derive auto-checkbox state.""" + + generation = context.generation_section + return [generation[key] for key in _AUTO_CHECKBOX_INPUT_KEYS] + + +def build_mode_ui_outputs(context: GenerationWiringContext) -> list[Any]: + """Return ordered mode-UI outputs shared across mode/remix/repaint wiring.""" + + generation = context.generation_section + return [generation[key] for key in _MODE_UI_OUTPUT_KEYS] diff --git a/acestep/ui/gradio/events/wiring/context_test.py b/acestep/ui/gradio/events/wiring/context_test.py new file mode 100644 index 00000000..e8c9a619 --- /dev/null +++ b/acestep/ui/gradio/events/wiring/context_test.py @@ -0,0 +1,142 @@ +"""Unit tests for event-wiring context helpers.""" + +import importlib.util +from pathlib import Path +import unittest + + +AUTO_OUTPUT_EXPECTED = [ + "bpm_auto", + "key_auto", + "timesig_auto", + "vocal_lang_auto", + "duration_auto", + "bpm", + "key_scale", + "time_signature", + "vocal_language", + "audio_duration", +] + +AUTO_INPUT_EXPECTED = [ + "bpm", + "key_scale", + "time_signature", + "vocal_language", + "audio_duration", +] + +MODE_OUTPUT_EXPECTED = [ + "simple_mode_group", + "custom_mode_group", + "generate_btn", + "simple_sample_created", + "optional_params_accordion", + "task_type", + "src_audio_row", + "repainting_group", + "text2music_audio_codes_group", + "track_name", + "complete_track_classes", + "generate_btn_row", + "generation_mode", + "results_wrapper", + "think_checkbox", + "load_file_col", + "load_file", + "audio_cover_strength", + "cover_noise_strength", + "captions", + "lyrics", + "bpm", + "key_scale", + "time_signature", + "vocal_language", + "audio_duration", + "auto_score", + "autogen_checkbox", + "auto_lrc", + "analyze_btn", + "repainting_header_html", + "repainting_start", + "repainting_end", + "previous_generation_mode", + "remix_help_group", + "extract_help_group", + "complete_help_group", + "bpm_auto", + "key_auto", + "timesig_auto", + "vocal_lang_auto", + "duration_auto", + "text2music_audio_code_string", + "src_audio", +] + + +def _load_context_module(): + """Load the context module directly from disk without package side effects.""" + module_path = Path(__file__).with_name("context.py") + spec = importlib.util.spec_from_file_location("events_wiring_context", module_path) + if spec is None or spec.loader is None: + raise RuntimeError(f"Unable to load spec for {module_path}") + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + +_MODULE = _load_context_module() +GenerationWiringContext = _MODULE.GenerationWiringContext +TrainingWiringContext = _MODULE.TrainingWiringContext +build_auto_checkbox_inputs = _MODULE.build_auto_checkbox_inputs +build_auto_checkbox_outputs = _MODULE.build_auto_checkbox_outputs +build_mode_ui_outputs = _MODULE.build_mode_ui_outputs + + +class GenerationWiringContextTests(unittest.TestCase): + """Verify ordered component-list builders used by event wiring.""" + + def setUp(self): + """Build a minimal context with deterministic component values.""" + required_keys = set(MODE_OUTPUT_EXPECTED + AUTO_OUTPUT_EXPECTED + AUTO_INPUT_EXPECTED) + self.generation_section = {key: key for key in required_keys} + self.context = GenerationWiringContext( + demo=object(), + dit_handler=object(), + llm_handler=object(), + dataset_handler=object(), + dataset_section={}, + generation_section=self.generation_section, + results_section={}, + ) + + def test_build_auto_checkbox_outputs_uses_expected_order(self): + """Auto checkbox output order must remain stable across refactors.""" + self.assertEqual(build_auto_checkbox_outputs(self.context), AUTO_OUTPUT_EXPECTED) + + def test_build_auto_checkbox_inputs_uses_expected_order(self): + """Auto checkbox input order must remain stable across refactors.""" + self.assertEqual(build_auto_checkbox_inputs(self.context), AUTO_INPUT_EXPECTED) + + def test_build_mode_ui_outputs_uses_expected_order(self): + """Mode output list must match wiring contract for event handlers.""" + self.assertEqual(build_mode_ui_outputs(self.context), MODE_OUTPUT_EXPECTED) + + +class TrainingWiringContextTests(unittest.TestCase): + """Verify the training context stores expected references.""" + + def test_training_context_keeps_training_section_reference(self): + """Training context should keep the original section mapping.""" + training_section = {"training_progress": "training_progress"} + context = TrainingWiringContext( + demo=object(), + dit_handler=object(), + llm_handler=object(), + training_section=training_section, + ) + self.assertIs(context.training_section, training_section) + + +if __name__ == "__main__": + unittest.main()