diff --git a/acestep/ui/gradio/events/__init__.py b/acestep/ui/gradio/events/__init__.py index 9bc9da42..f5ddf83e 100644 --- a/acestep/ui/gradio/events/__init__.py +++ b/acestep/ui/gradio/events/__init__.py @@ -13,9 +13,9 @@ from .wiring import ( GenerationWiringContext, TrainingWiringContext, - build_auto_checkbox_inputs, - build_auto_checkbox_outputs, build_mode_ui_outputs, + register_generation_metadata_handlers, + register_generation_service_handlers, ) from acestep.ui.gradio.i18n import t @@ -32,385 +32,16 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase results_section=results_section, ) - # ========== Dataset Handlers ========== - dataset_section["import_dataset_btn"].click( - fn=dataset_handler.import_dataset, - inputs=[dataset_section["dataset_type"]], - outputs=[dataset_section["data_status"]] + auto_checkbox_inputs, auto_checkbox_outputs = register_generation_service_handlers( + wiring_context ) - - # ========== Service Initialization ========== - generation_section["refresh_btn"].click( - fn=lambda: gen_h.refresh_checkpoints(dit_handler), - outputs=[generation_section["checkpoint_dropdown"]] - ) - - generation_section["config_path"].change( - fn=gen_h.update_model_type_settings, - inputs=[generation_section["config_path"], generation_section["generation_mode"]], - outputs=[ - generation_section["inference_steps"], - generation_section["guidance_scale"], - generation_section["use_adg"], - generation_section["shift"], - generation_section["cfg_interval_start"], - generation_section["cfg_interval_end"], - generation_section["task_type"], - generation_section["generation_mode"], - generation_section["init_llm_checkbox"], - ] - ) - - # ========== Tier Override ========== - generation_section["tier_dropdown"].change( - fn=lambda tier: gen_h.on_tier_change(tier, llm_handler), - inputs=[generation_section["tier_dropdown"]], - outputs=[ - generation_section["offload_to_cpu_checkbox"], - generation_section["offload_dit_to_cpu_checkbox"], - generation_section["compile_model_checkbox"], - generation_section["quantization_checkbox"], - generation_section["backend_dropdown"], - generation_section["lm_model_path"], - generation_section["init_llm_checkbox"], - generation_section["batch_size_input"], - generation_section["audio_duration"], - generation_section["gpu_info_display"], - ] - ) - - generation_section["init_btn"].click( - fn=lambda *args: gen_h.init_service_wrapper(dit_handler, llm_handler, *args), - inputs=[ - generation_section["checkpoint_dropdown"], - generation_section["config_path"], - generation_section["device"], - generation_section["init_llm_checkbox"], - generation_section["lm_model_path"], - generation_section["backend_dropdown"], - generation_section["use_flash_attention_checkbox"], - generation_section["offload_to_cpu_checkbox"], - generation_section["offload_dit_to_cpu_checkbox"], - generation_section["compile_model_checkbox"], - generation_section["quantization_checkbox"], - generation_section["mlx_dit_checkbox"], - generation_section["generation_mode"], # preserve current mode across init - generation_section["batch_size_input"], # preserve current batch_size across init - ], - outputs=[ - generation_section["init_status"], - generation_section["generate_btn"], - generation_section["service_config_accordion"], - # Model type settings (updated based on actual loaded model) - generation_section["inference_steps"], - generation_section["guidance_scale"], - generation_section["use_adg"], - generation_section["shift"], - generation_section["cfg_interval_start"], - generation_section["cfg_interval_end"], - generation_section["task_type"], - generation_section["generation_mode"], - generation_section["init_llm_checkbox"], - # GPU-config-aware limits (updated after initialization) - generation_section["audio_duration"], - generation_section["batch_size_input"], - # Think checkbox: enable if LLM initialized - generation_section["think_checkbox"], - ] - ) - - # ========== LoRA Handlers ========== - generation_section["load_lora_btn"].click( - fn=dit_handler.load_lora, - inputs=[generation_section["lora_path"]], - outputs=[generation_section["lora_status"]] - ).then( - # Update checkbox to enabled state after loading - fn=lambda: gr.update(value=True), - outputs=[generation_section["use_lora_checkbox"]] - ) - - generation_section["unload_lora_btn"].click( - fn=dit_handler.unload_lora, - outputs=[generation_section["lora_status"]] - ).then( - # Update checkbox to disabled state after unloading - fn=lambda: gr.update(value=False), - outputs=[generation_section["use_lora_checkbox"]] - ) - - generation_section["use_lora_checkbox"].change( - fn=dit_handler.set_use_lora, - inputs=[generation_section["use_lora_checkbox"]], - outputs=[generation_section["lora_status"]] - ) - - generation_section["lora_scale_slider"].change( - fn=dit_handler.set_lora_scale, - inputs=[generation_section["lora_scale_slider"]], - outputs=[generation_section["lora_status"]] - ) - - # ========== Auto Checkbox Handlers ========== - _auto_field_map = { - "bpm_auto": ("bpm", "bpm"), - "key_auto": ("key_scale", "key_scale"), - "timesig_auto": ("time_signature", "time_signature"), - "vocal_lang_auto": ("vocal_language", "vocal_language"), - "duration_auto": ("audio_duration", "audio_duration"), - } - for auto_key, (field_name, comp_key) in _auto_field_map.items(): - generation_section[auto_key].change( - fn=lambda checked, fn=field_name: gen_h.on_auto_checkbox_change(checked, fn), - inputs=[generation_section[auto_key]], - outputs=[generation_section[comp_key]], - ) - - generation_section["reset_all_auto_btn"].click( - fn=gen_h.reset_all_auto, - outputs=[ - generation_section["bpm_auto"], - generation_section["key_auto"], - generation_section["timesig_auto"], - generation_section["vocal_lang_auto"], - generation_section["duration_auto"], - generation_section["bpm"], - generation_section["key_scale"], - generation_section["time_signature"], - generation_section["vocal_language"], - generation_section["audio_duration"], - ], - ) - - # ========== UI Visibility Updates ========== - generation_section["init_llm_checkbox"].change( - fn=gen_h.update_negative_prompt_visibility, - inputs=[generation_section["init_llm_checkbox"]], - outputs=[generation_section["lm_negative_prompt"]] - ) - - generation_section["batch_size_input"].change( - fn=gen_h.update_audio_components_visibility, - inputs=[generation_section["batch_size_input"]], - outputs=[ - results_section["audio_col_1"], - results_section["audio_col_2"], - results_section["audio_col_3"], - results_section["audio_col_4"], - results_section["audio_row_5_8"], - results_section["audio_col_5"], - results_section["audio_col_6"], - results_section["audio_col_7"], - results_section["audio_col_8"], - ] - ) - - # Auto-checkbox outputs used by .then() chains after events that populate fields - _auto_checkbox_outputs = build_auto_checkbox_outputs(wiring_context) - _auto_checkbox_inputs = build_auto_checkbox_inputs(wiring_context) - _mode_ui_outputs = build_mode_ui_outputs(wiring_context) - - # ========== Audio Conversion (LM Codes Hints accordion in Custom mode) ========== - generation_section["convert_src_to_codes_btn"].click( - fn=lambda src: gen_h.convert_src_audio_to_codes_wrapper(dit_handler, src), - inputs=[generation_section["lm_codes_audio_upload"]], - outputs=[generation_section["text2music_audio_code_string"]] - ) - - # ========== Analyze Source Audio (Remix/Repaint: convert to codes + transcribe) ========== - generation_section["analyze_btn"].click( - fn=lambda src, debug: gen_h.analyze_src_audio(dit_handler, llm_handler, src, debug), - inputs=[ - generation_section["src_audio"], - generation_section["constrained_decoding_debug"], - ], - outputs=[ - generation_section["text2music_audio_code_string"], - results_section["status_output"], - generation_section["captions"], - generation_section["lyrics"], - generation_section["bpm"], - generation_section["audio_duration"], - generation_section["key_scale"], - generation_section["vocal_language"], - generation_section["time_signature"], - results_section["is_format_caption_state"], - ] - ).then( - fn=gen_h.uncheck_auto_for_populated_fields, - inputs=_auto_checkbox_inputs, - outputs=_auto_checkbox_outputs, + mode_ui_outputs = build_mode_ui_outputs(wiring_context) + register_generation_metadata_handlers( + wiring_context, + auto_checkbox_inputs=auto_checkbox_inputs, + auto_checkbox_outputs=auto_checkbox_outputs, ) - - # ========== Instruction UI Updates ========== - # Visibility of track_name / complete_track_classes / repainting_group is - # handled by compute_mode_ui_updates; this handler only refreshes the - # instruction text when relevant inputs change. - for trigger in [generation_section["task_type"], generation_section["track_name"], generation_section["complete_track_classes"], generation_section["reference_audio"]]: - trigger.change( - fn=lambda *args: gen_h.update_instruction_ui(dit_handler, *args), - inputs=[ - generation_section["task_type"], - generation_section["track_name"], - generation_section["complete_track_classes"], - generation_section["init_llm_checkbox"], - generation_section["reference_audio"], - ], - outputs=[ - generation_section["instruction_display_gen"], - ] - ) - # Validate reference audio eagerly so users get immediate feedback on invalid files. - generation_section["reference_audio"].change( - fn=lambda reference_audio: gen_h.validate_uploaded_audio_file(reference_audio, "reference"), - inputs=[generation_section["reference_audio"]], - outputs=[generation_section["reference_audio"]], - ) - - # ========== Sample/Transcribe Handlers ========== - # Load random example from ./examples/text2music directory - generation_section["sample_btn"].click( - fn=lambda task: gen_h.load_random_example(task, llm_handler) + (True,), - inputs=[ - generation_section["task_type"], - ], - outputs=[ - generation_section["captions"], - generation_section["lyrics"], - generation_section["think_checkbox"], - generation_section["bpm"], - generation_section["audio_duration"], - generation_section["key_scale"], - generation_section["vocal_language"], - generation_section["time_signature"], - results_section["is_format_caption_state"] - ] - ).then( - fn=gen_h.uncheck_auto_for_populated_fields, - inputs=_auto_checkbox_inputs, - outputs=_auto_checkbox_outputs, - ) - - generation_section["text2music_audio_code_string"].change( - fn=gen_h.update_transcribe_button_text, - inputs=[generation_section["text2music_audio_code_string"]], - outputs=[generation_section["transcribe_btn"]] - ) - - generation_section["transcribe_btn"].click( - fn=lambda codes, debug: gen_h.transcribe_audio_codes(llm_handler, codes, debug), - inputs=[ - generation_section["text2music_audio_code_string"], - generation_section["constrained_decoding_debug"] - ], - outputs=[ - results_section["status_output"], - generation_section["captions"], - generation_section["lyrics"], - generation_section["bpm"], - generation_section["audio_duration"], - generation_section["key_scale"], - generation_section["vocal_language"], - generation_section["time_signature"], - results_section["is_format_caption_state"] - ] - ).then( - fn=gen_h.uncheck_auto_for_populated_fields, - inputs=_auto_checkbox_inputs, - outputs=_auto_checkbox_outputs, - ) - - # ========== Reset Format Caption Flag ========== - for trigger in [generation_section["captions"], generation_section["lyrics"], generation_section["bpm"], - generation_section["key_scale"], generation_section["time_signature"], - generation_section["vocal_language"], generation_section["audio_duration"]]: - trigger.change( - fn=gen_h.reset_format_caption_flag, - inputs=[], - outputs=[results_section["is_format_caption_state"]] - ) - - # ========== Instrumental Checkbox ========== - generation_section["instrumental_checkbox"].change( - fn=gen_h.handle_instrumental_checkbox, - inputs=[ - generation_section["instrumental_checkbox"], - generation_section["lyrics"], - generation_section["lyrics_before_instrumental"], - ], - outputs=[ - generation_section["lyrics"], - generation_section["lyrics_before_instrumental"], - ] - ) - - # ========== Format Caption Button ========== - generation_section["format_caption_btn"].click( - fn=lambda caption, lyrics, bpm, duration, key_scale, time_sig, temp, top_k, top_p, debug: gen_h.handle_format_caption( - llm_handler, caption, lyrics, bpm, duration, key_scale, time_sig, temp, top_k, top_p, debug - ), - inputs=[ - generation_section["captions"], - generation_section["lyrics"], - generation_section["bpm"], - generation_section["audio_duration"], - generation_section["key_scale"], - generation_section["time_signature"], - generation_section["lm_temperature"], - generation_section["lm_top_k"], - generation_section["lm_top_p"], - generation_section["constrained_decoding_debug"], - ], - outputs=[ - generation_section["captions"], - generation_section["bpm"], - generation_section["audio_duration"], - generation_section["key_scale"], - generation_section["vocal_language"], - generation_section["time_signature"], - results_section["is_format_caption_state"], - results_section["status_output"], - ] - ).then( - fn=gen_h.uncheck_auto_for_populated_fields, - inputs=_auto_checkbox_inputs, - outputs=_auto_checkbox_outputs, - ) - - # ========== Format Lyrics Button ========== - generation_section["format_lyrics_btn"].click( - fn=lambda caption, lyrics, bpm, duration, key_scale, time_sig, temp, top_k, top_p, debug: gen_h.handle_format_lyrics( - llm_handler, caption, lyrics, bpm, duration, key_scale, time_sig, temp, top_k, top_p, debug - ), - inputs=[ - generation_section["captions"], - generation_section["lyrics"], - generation_section["bpm"], - generation_section["audio_duration"], - generation_section["key_scale"], - generation_section["time_signature"], - generation_section["lm_temperature"], - generation_section["lm_top_k"], - generation_section["lm_top_p"], - generation_section["constrained_decoding_debug"], - ], - outputs=[ - generation_section["lyrics"], - generation_section["bpm"], - generation_section["audio_duration"], - generation_section["key_scale"], - generation_section["vocal_language"], - generation_section["time_signature"], - results_section["is_format_caption_state"], - results_section["status_output"], - ] - ).then( - fn=gen_h.uncheck_auto_for_populated_fields, - inputs=_auto_checkbox_inputs, - outputs=_auto_checkbox_outputs, - ) - # ========== Generation Mode Change ========== generation_section["generation_mode"].change( fn=lambda mode, prev: gen_h.handle_generation_mode_change(mode, prev, llm_handler), @@ -418,7 +49,7 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase generation_section["generation_mode"], generation_section["previous_generation_mode"], ], - outputs=_mode_ui_outputs + outputs=mode_ui_outputs ) # ========== Extract Mode: Auto-fill caption from track_name ========== @@ -501,8 +132,8 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase ] ).then( fn=gen_h.uncheck_auto_for_populated_fields, - inputs=_auto_checkbox_inputs, - outputs=_auto_checkbox_outputs, + inputs=auto_checkbox_inputs, + outputs=auto_checkbox_outputs, ) # ========== Load/Save Metadata ========== @@ -551,8 +182,8 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase ] ).then( fn=gen_h.uncheck_auto_for_populated_fields, - inputs=_auto_checkbox_inputs, - outputs=_auto_checkbox_outputs, + inputs=auto_checkbox_inputs, + outputs=auto_checkbox_outputs, ) # Save buttons for all 8 audio outputs @@ -652,7 +283,7 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase generation_section["generation_mode"], generation_section["lyrics"], generation_section["captions"], - ] + _mode_ui_outputs, + ] + mode_ui_outputs, ) results_section[f"send_to_repaint_btn_{btn_idx}"].click( fn=lambda audio, lm, ly, cap, cur_mode: res_h.send_audio_to_repaint( @@ -669,7 +300,7 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase generation_section["generation_mode"], generation_section["lyrics"], generation_section["captions"], - ] + _mode_ui_outputs, + ] + mode_ui_outputs, ) # ========== Score Calculation Handlers ========== diff --git a/acestep/ui/gradio/events/wiring/__init__.py b/acestep/ui/gradio/events/wiring/__init__.py index af28ef6a..fc944ff6 100644 --- a/acestep/ui/gradio/events/wiring/__init__.py +++ b/acestep/ui/gradio/events/wiring/__init__.py @@ -11,6 +11,8 @@ build_auto_checkbox_outputs, build_mode_ui_outputs, ) +from .generation_metadata_wiring import register_generation_metadata_handlers +from .generation_service_wiring import register_generation_service_handlers __all__ = [ "GenerationWiringContext", @@ -18,4 +20,6 @@ "build_auto_checkbox_inputs", "build_auto_checkbox_outputs", "build_mode_ui_outputs", + "register_generation_metadata_handlers", + "register_generation_service_handlers", ] diff --git a/acestep/ui/gradio/events/wiring/decomposition_contract_test.py b/acestep/ui/gradio/events/wiring/decomposition_contract_test.py new file mode 100644 index 00000000..b6476419 --- /dev/null +++ b/acestep/ui/gradio/events/wiring/decomposition_contract_test.py @@ -0,0 +1,78 @@ +"""Regression tests for PR2 wiring decomposition contracts. + +These tests validate source-level delegation in +``acestep.ui.gradio.events.__init__`` without importing Gradio dependencies. +""" + +import ast +from pathlib import Path +import unittest + + +_EVENTS_INIT_PATH = Path(__file__).resolve().parents[1] / "__init__.py" + + +def _load_setup_event_handlers_node() -> ast.FunctionDef: + """Return the AST node for ``setup_event_handlers``.""" + + source = _EVENTS_INIT_PATH.read_text(encoding="utf-8") + module = ast.parse(source) + for node in module.body: + if isinstance(node, ast.FunctionDef) and node.name == "setup_event_handlers": + return node + raise AssertionError("setup_event_handlers not found") + + +def _call_name(node: ast.AST) -> str | None: + """Extract a simple function name from a call node target.""" + + if isinstance(node, ast.Name): + return node.id + if isinstance(node, ast.Attribute): + return node.attr + return None + + +class DecompositionContractTests(unittest.TestCase): + """Verify delegation contracts introduced in PR2 wiring extraction.""" + + def test_setup_event_handlers_uses_generation_wiring_helpers(self): + """setup_event_handlers should delegate service/metadata registration.""" + + setup_node = _load_setup_event_handlers_node() + call_names = [] + for node in ast.walk(setup_node): + if isinstance(node, ast.Call): + name = _call_name(node.func) + if name: + call_names.append(name) + + self.assertIn("register_generation_service_handlers", call_names) + self.assertIn("register_generation_metadata_handlers", call_names) + self.assertIn("build_mode_ui_outputs", call_names) + + def test_generation_mode_change_uses_mode_ui_outputs_variable(self): + """generation_mode change handler should still output mode_ui_outputs.""" + + setup_node = _load_setup_event_handlers_node() + found_mode_change_output_binding = False + + for node in ast.walk(setup_node): + if not isinstance(node, ast.Call): + continue + if not isinstance(node.func, ast.Attribute) or node.func.attr != "change": + continue + for keyword in node.keywords: + if keyword.arg != "outputs": + continue + if isinstance(keyword.value, ast.Name) and keyword.value.id == "mode_ui_outputs": + found_mode_change_output_binding = True + break + if found_mode_change_output_binding: + break + + self.assertTrue(found_mode_change_output_binding) + + +if __name__ == "__main__": + unittest.main() diff --git a/acestep/ui/gradio/events/wiring/generation_metadata_wiring.py b/acestep/ui/gradio/events/wiring/generation_metadata_wiring.py new file mode 100644 index 00000000..305e2919 --- /dev/null +++ b/acestep/ui/gradio/events/wiring/generation_metadata_wiring.py @@ -0,0 +1,168 @@ +"""Generation metadata/text event wiring helpers. + +This module contains wiring for source analysis, transcribe/sample operations, +and caption/lyrics formatting flows. +""" + +from typing import Any, Sequence + +from .. import generation_handlers as gen_h +from .context import GenerationWiringContext +from .generation_text_format_wiring import register_generation_text_format_handlers + + +def register_generation_metadata_handlers( + context: GenerationWiringContext, + auto_checkbox_inputs: Sequence[Any], + auto_checkbox_outputs: Sequence[Any], +) -> None: + """Register metadata and text-format generation handlers.""" + + generation_section = context.generation_section + results_section = context.results_section + dit_handler = context.dit_handler + llm_handler = context.llm_handler + + # ========== Audio Conversion (LM Codes Hints accordion in Custom mode) ========== + generation_section["convert_src_to_codes_btn"].click( + fn=lambda src: gen_h.convert_src_audio_to_codes_wrapper(dit_handler, src), + inputs=[generation_section["lm_codes_audio_upload"]], + outputs=[generation_section["text2music_audio_code_string"]], + ) + + # ========== Analyze Source Audio (Remix/Repaint: convert to codes + transcribe) ========== + generation_section["analyze_btn"].click( + fn=lambda src, debug: gen_h.analyze_src_audio(dit_handler, llm_handler, src, debug), + inputs=[ + generation_section["src_audio"], + generation_section["constrained_decoding_debug"], + ], + outputs=[ + generation_section["text2music_audio_code_string"], + results_section["status_output"], + generation_section["captions"], + generation_section["lyrics"], + generation_section["bpm"], + generation_section["audio_duration"], + generation_section["key_scale"], + generation_section["vocal_language"], + generation_section["time_signature"], + results_section["is_format_caption_state"], + ], + ).then( + fn=gen_h.uncheck_auto_for_populated_fields, + inputs=list(auto_checkbox_inputs), + outputs=list(auto_checkbox_outputs), + ) + + # ========== Instruction UI Updates ========== + for trigger in [ + generation_section["task_type"], + generation_section["track_name"], + generation_section["complete_track_classes"], + generation_section["reference_audio"], + ]: + trigger.change( + fn=lambda *args: gen_h.update_instruction_ui(dit_handler, *args), + inputs=[ + generation_section["task_type"], + generation_section["track_name"], + generation_section["complete_track_classes"], + generation_section["init_llm_checkbox"], + generation_section["reference_audio"], + ], + outputs=[generation_section["instruction_display_gen"]], + ) + + # Validate reference audio eagerly so users get immediate feedback on invalid files. + generation_section["reference_audio"].change( + fn=lambda reference_audio: gen_h.validate_uploaded_audio_file(reference_audio, "reference"), + inputs=[generation_section["reference_audio"]], + outputs=[generation_section["reference_audio"]], + ) + + # ========== Sample/Transcribe Handlers ========== + generation_section["sample_btn"].click( + fn=lambda task: gen_h.load_random_example(task, llm_handler) + (True,), + inputs=[generation_section["task_type"]], + outputs=[ + generation_section["captions"], + generation_section["lyrics"], + generation_section["think_checkbox"], + generation_section["bpm"], + generation_section["audio_duration"], + generation_section["key_scale"], + generation_section["vocal_language"], + generation_section["time_signature"], + results_section["is_format_caption_state"], + ], + ).then( + fn=gen_h.uncheck_auto_for_populated_fields, + inputs=list(auto_checkbox_inputs), + outputs=list(auto_checkbox_outputs), + ) + + generation_section["text2music_audio_code_string"].change( + fn=gen_h.update_transcribe_button_text, + inputs=[generation_section["text2music_audio_code_string"]], + outputs=[generation_section["transcribe_btn"]], + ) + + generation_section["transcribe_btn"].click( + fn=lambda codes, debug: gen_h.transcribe_audio_codes(llm_handler, codes, debug), + inputs=[ + generation_section["text2music_audio_code_string"], + generation_section["constrained_decoding_debug"], + ], + outputs=[ + results_section["status_output"], + generation_section["captions"], + generation_section["lyrics"], + generation_section["bpm"], + generation_section["audio_duration"], + generation_section["key_scale"], + generation_section["vocal_language"], + generation_section["time_signature"], + results_section["is_format_caption_state"], + ], + ).then( + fn=gen_h.uncheck_auto_for_populated_fields, + inputs=list(auto_checkbox_inputs), + outputs=list(auto_checkbox_outputs), + ) + + # ========== Reset Format Caption Flag ========== + for trigger in [ + generation_section["captions"], + generation_section["lyrics"], + generation_section["bpm"], + generation_section["key_scale"], + generation_section["time_signature"], + generation_section["vocal_language"], + generation_section["audio_duration"], + ]: + trigger.change( + fn=gen_h.reset_format_caption_flag, + inputs=[], + outputs=[results_section["is_format_caption_state"]], + ) + + # ========== Instrumental Checkbox ========== + generation_section["instrumental_checkbox"].change( + fn=gen_h.handle_instrumental_checkbox, + inputs=[ + generation_section["instrumental_checkbox"], + generation_section["lyrics"], + generation_section["lyrics_before_instrumental"], + ], + outputs=[ + generation_section["lyrics"], + generation_section["lyrics_before_instrumental"], + ], + ) + + register_generation_text_format_handlers( + context, + auto_checkbox_inputs=auto_checkbox_inputs, + auto_checkbox_outputs=auto_checkbox_outputs, + ) diff --git a/acestep/ui/gradio/events/wiring/generation_service_wiring.py b/acestep/ui/gradio/events/wiring/generation_service_wiring.py new file mode 100644 index 00000000..ded273f9 --- /dev/null +++ b/acestep/ui/gradio/events/wiring/generation_service_wiring.py @@ -0,0 +1,191 @@ +"""Generation service-layer event wiring helpers. + +This module contains wiring related to service initialization, LoRA controls, +auto-checkbox controls, and visibility updates for generation components. +""" + +from typing import Any + +import gradio as gr + +from .. import generation_handlers as gen_h +from .context import ( + GenerationWiringContext, + build_auto_checkbox_inputs, + build_auto_checkbox_outputs, +) + + +def register_generation_service_handlers( + context: GenerationWiringContext, +) -> tuple[list[Any], list[Any]]: + """Register generation service/init handlers and return auto-checkbox lists.""" + + dataset_section = context.dataset_section + generation_section = context.generation_section + results_section = context.results_section + dit_handler = context.dit_handler + llm_handler = context.llm_handler + dataset_handler = context.dataset_handler + + # ========== Dataset Handlers ========== + dataset_section["import_dataset_btn"].click( + fn=dataset_handler.import_dataset, + inputs=[dataset_section["dataset_type"]], + outputs=[dataset_section["data_status"]], + ) + + # ========== Service Initialization ========== + generation_section["refresh_btn"].click( + fn=lambda: gen_h.refresh_checkpoints(dit_handler), + outputs=[generation_section["checkpoint_dropdown"]], + ) + + generation_section["config_path"].change( + fn=gen_h.update_model_type_settings, + inputs=[generation_section["config_path"], generation_section["generation_mode"]], + outputs=[ + generation_section["inference_steps"], + generation_section["guidance_scale"], + generation_section["use_adg"], + generation_section["shift"], + generation_section["cfg_interval_start"], + generation_section["cfg_interval_end"], + generation_section["task_type"], + generation_section["generation_mode"], + generation_section["init_llm_checkbox"], + ], + ) + + # ========== Tier Override ========== + generation_section["tier_dropdown"].change( + fn=lambda tier: gen_h.on_tier_change(tier, llm_handler), + inputs=[generation_section["tier_dropdown"]], + outputs=[ + generation_section["offload_to_cpu_checkbox"], + generation_section["offload_dit_to_cpu_checkbox"], + generation_section["compile_model_checkbox"], + generation_section["quantization_checkbox"], + generation_section["backend_dropdown"], + generation_section["lm_model_path"], + generation_section["init_llm_checkbox"], + generation_section["batch_size_input"], + generation_section["audio_duration"], + generation_section["gpu_info_display"], + ], + ) + + generation_section["init_btn"].click( + fn=lambda *args: gen_h.init_service_wrapper(dit_handler, llm_handler, *args), + inputs=[ + generation_section["checkpoint_dropdown"], + generation_section["config_path"], + generation_section["device"], + generation_section["init_llm_checkbox"], + generation_section["lm_model_path"], + generation_section["backend_dropdown"], + generation_section["use_flash_attention_checkbox"], + generation_section["offload_to_cpu_checkbox"], + generation_section["offload_dit_to_cpu_checkbox"], + generation_section["compile_model_checkbox"], + generation_section["quantization_checkbox"], + generation_section["mlx_dit_checkbox"], + generation_section["generation_mode"], + generation_section["batch_size_input"], + ], + outputs=[ + generation_section["init_status"], + generation_section["generate_btn"], + generation_section["service_config_accordion"], + generation_section["inference_steps"], + generation_section["guidance_scale"], + generation_section["use_adg"], + generation_section["shift"], + generation_section["cfg_interval_start"], + generation_section["cfg_interval_end"], + generation_section["task_type"], + generation_section["generation_mode"], + generation_section["init_llm_checkbox"], + generation_section["audio_duration"], + generation_section["batch_size_input"], + generation_section["think_checkbox"], + ], + ) + + # ========== LoRA Handlers ========== + generation_section["load_lora_btn"].click( + fn=dit_handler.load_lora, + inputs=[generation_section["lora_path"]], + outputs=[generation_section["lora_status"]], + ).then( + fn=lambda: gr.update(value=True), + outputs=[generation_section["use_lora_checkbox"]], + ) + + generation_section["unload_lora_btn"].click( + fn=dit_handler.unload_lora, + outputs=[generation_section["lora_status"]], + ).then( + fn=lambda: gr.update(value=False), + outputs=[generation_section["use_lora_checkbox"]], + ) + + generation_section["use_lora_checkbox"].change( + fn=dit_handler.set_use_lora, + inputs=[generation_section["use_lora_checkbox"]], + outputs=[generation_section["lora_status"]], + ) + + generation_section["lora_scale_slider"].change( + fn=dit_handler.set_lora_scale, + inputs=[generation_section["lora_scale_slider"]], + outputs=[generation_section["lora_status"]], + ) + + # ========== Auto Checkbox Handlers ========== + auto_field_map = { + "bpm_auto": ("bpm", "bpm"), + "key_auto": ("key_scale", "key_scale"), + "timesig_auto": ("time_signature", "time_signature"), + "vocal_lang_auto": ("vocal_language", "vocal_language"), + "duration_auto": ("audio_duration", "audio_duration"), + } + for auto_key, (field_name, comp_key) in auto_field_map.items(): + generation_section[auto_key].change( + fn=lambda checked, fn=field_name: gen_h.on_auto_checkbox_change(checked, fn), + inputs=[generation_section[auto_key]], + outputs=[generation_section[comp_key]], + ) + + auto_checkbox_outputs = build_auto_checkbox_outputs(context) + auto_checkbox_inputs = build_auto_checkbox_inputs(context) + + generation_section["reset_all_auto_btn"].click( + fn=gen_h.reset_all_auto, + outputs=auto_checkbox_outputs, + ) + + # ========== UI Visibility Updates ========== + generation_section["init_llm_checkbox"].change( + fn=gen_h.update_negative_prompt_visibility, + inputs=[generation_section["init_llm_checkbox"]], + outputs=[generation_section["lm_negative_prompt"]], + ) + + generation_section["batch_size_input"].change( + fn=gen_h.update_audio_components_visibility, + inputs=[generation_section["batch_size_input"]], + outputs=[ + results_section["audio_col_1"], + results_section["audio_col_2"], + results_section["audio_col_3"], + results_section["audio_col_4"], + results_section["audio_row_5_8"], + results_section["audio_col_5"], + results_section["audio_col_6"], + results_section["audio_col_7"], + results_section["audio_col_8"], + ], + ) + + return auto_checkbox_inputs, auto_checkbox_outputs diff --git a/acestep/ui/gradio/events/wiring/generation_text_format_wiring.py b/acestep/ui/gradio/events/wiring/generation_text_format_wiring.py new file mode 100644 index 00000000..d788839d --- /dev/null +++ b/acestep/ui/gradio/events/wiring/generation_text_format_wiring.py @@ -0,0 +1,87 @@ +"""Generation text-format event wiring helpers. + +This module isolates caption/lyrics formatting event registration. +""" + +from typing import Any, Sequence + +from .. import generation_handlers as gen_h +from .context import GenerationWiringContext + + +def register_generation_text_format_handlers( + context: GenerationWiringContext, + auto_checkbox_inputs: Sequence[Any], + auto_checkbox_outputs: Sequence[Any], +) -> None: + """Register caption/lyrics format handlers and their auto-checkbox sync.""" + + generation_section = context.generation_section + results_section = context.results_section + llm_handler = context.llm_handler + + # ========== Format Caption Button ========== + generation_section["format_caption_btn"].click( + fn=lambda caption, lyrics, bpm, duration, key_scale, time_sig, temp, top_k, top_p, debug: gen_h.handle_format_caption( + llm_handler, caption, lyrics, bpm, duration, key_scale, time_sig, temp, top_k, top_p, debug + ), + inputs=[ + generation_section["captions"], + generation_section["lyrics"], + generation_section["bpm"], + generation_section["audio_duration"], + generation_section["key_scale"], + generation_section["time_signature"], + generation_section["lm_temperature"], + generation_section["lm_top_k"], + generation_section["lm_top_p"], + generation_section["constrained_decoding_debug"], + ], + outputs=[ + generation_section["captions"], + generation_section["bpm"], + generation_section["audio_duration"], + generation_section["key_scale"], + generation_section["vocal_language"], + generation_section["time_signature"], + results_section["is_format_caption_state"], + results_section["status_output"], + ], + ).then( + fn=gen_h.uncheck_auto_for_populated_fields, + inputs=list(auto_checkbox_inputs), + outputs=list(auto_checkbox_outputs), + ) + + # ========== Format Lyrics Button ========== + generation_section["format_lyrics_btn"].click( + fn=lambda caption, lyrics, bpm, duration, key_scale, time_sig, temp, top_k, top_p, debug: gen_h.handle_format_lyrics( + llm_handler, caption, lyrics, bpm, duration, key_scale, time_sig, temp, top_k, top_p, debug + ), + inputs=[ + generation_section["captions"], + generation_section["lyrics"], + generation_section["bpm"], + generation_section["audio_duration"], + generation_section["key_scale"], + generation_section["time_signature"], + generation_section["lm_temperature"], + generation_section["lm_top_k"], + generation_section["lm_top_p"], + generation_section["constrained_decoding_debug"], + ], + outputs=[ + generation_section["lyrics"], + generation_section["bpm"], + generation_section["audio_duration"], + generation_section["key_scale"], + generation_section["vocal_language"], + generation_section["time_signature"], + results_section["is_format_caption_state"], + results_section["status_output"], + ], + ).then( + fn=gen_h.uncheck_auto_for_populated_fields, + inputs=list(auto_checkbox_inputs), + outputs=list(auto_checkbox_outputs), + )