diff --git a/acestep/ui/gradio/events/__init__.py b/acestep/ui/gradio/events/__init__.py index 117467c9..f78bcdc3 100644 --- a/acestep/ui/gradio/events/__init__.py +++ b/acestep/ui/gradio/events/__init__.py @@ -2,26 +2,25 @@ Gradio UI Event Handlers Module Main entry point for setting up all event handlers """ -import gradio as gr -from typing import Optional -from loguru import logger - # Import handler modules -from . import generation_handlers as gen_h -from . import results_handlers as res_h -from . import training_handlers as train_h from .wiring import ( GenerationWiringContext, TrainingWiringContext, build_mode_ui_outputs, register_generation_batch_navigation_handlers, + register_generation_metadata_file_handlers, register_generation_metadata_handlers, register_generation_mode_handlers, register_generation_run_handlers, register_results_aux_handlers, + register_results_restore_and_lrc_handlers, + register_results_save_button_handlers, register_generation_service_handlers, + register_training_dataset_builder_handlers, + register_training_dataset_load_handler, + register_training_preprocess_handler, + register_training_run_handlers, ) -from acestep.ui.gradio.i18n import t def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, dataset_section, generation_section, results_section): @@ -81,195 +80,19 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase auto_checkbox_outputs=auto_checkbox_outputs, ) - # ========== Load/Save Metadata ========== - generation_section["load_file"].upload( - fn=lambda file_obj: gen_h.load_metadata(file_obj, llm_handler), - inputs=[generation_section["load_file"]], - outputs=[ - generation_section["task_type"], - generation_section["captions"], - generation_section["lyrics"], - generation_section["vocal_language"], - generation_section["bpm"], - generation_section["key_scale"], - generation_section["time_signature"], - generation_section["audio_duration"], - generation_section["batch_size_input"], - generation_section["inference_steps"], - generation_section["guidance_scale"], - generation_section["seed"], - generation_section["random_seed_checkbox"], - generation_section["use_adg"], - generation_section["cfg_interval_start"], - generation_section["cfg_interval_end"], - generation_section["shift"], - generation_section["infer_method"], - generation_section["custom_timesteps"], - generation_section["audio_format"], - generation_section["lm_temperature"], - generation_section["lm_cfg_scale"], - generation_section["lm_top_k"], - generation_section["lm_top_p"], - generation_section["lm_negative_prompt"], - generation_section["use_cot_metas"], # Added: use_cot_metas - generation_section["use_cot_caption"], - generation_section["use_cot_language"], - generation_section["audio_cover_strength"], - generation_section["cover_noise_strength"], - generation_section["think_checkbox"], - generation_section["text2music_audio_code_string"], - generation_section["repainting_start"], - generation_section["repainting_end"], - generation_section["track_name"], - generation_section["complete_track_classes"], - generation_section["instrumental_checkbox"], # Added: instrumental_checkbox - results_section["is_format_caption_state"] - ] - ).then( - fn=gen_h.uncheck_auto_for_populated_fields, - inputs=auto_checkbox_inputs, - outputs=auto_checkbox_outputs, - ) - - # Save buttons for all 8 audio outputs - download_existing_js = """(current_audio, batch_files) => { - // Debug: print what the input actually is - console.log("šŸ‘‰ [Debug] Current Audio Input:", current_audio); - - // 1. Safety check - if (!current_audio) { - console.warn("āš ļø No audio selected or audio is empty."); - return; - } - if (!batch_files || !Array.isArray(batch_files)) { - console.warn("āš ļø Batch file list is empty/not ready."); - return; - } - - // 2. Smartly extract path string - let pathString = ""; - - if (typeof current_audio === "string") { - // Case A: direct path string received - pathString = current_audio; - } else if (typeof current_audio === "object") { - // Case B: an object is received, try common properties - // Gradio file objects usually have path, url, or name - pathString = current_audio.path || current_audio.name || current_audio.url || ""; - } - - if (!pathString) { - console.error("āŒ Error: Could not extract a valid path string from input.", current_audio); - return; - } - - // 3. Extract Key (UUID) - // Path could be /tmp/.../uuid.mp3 or url like /file=.../uuid.mp3 - let filename = pathString.split(/[\\\\/]/).pop(); // get the filename - let key = filename.split('.')[0]; // get UUID without extension - - console.log(`šŸ”‘ Key extracted: ${key}`); - - // 4. Find matching file(s) in the list - let targets = batch_files.filter(f => { - // Also extract names from batch_files objects - // f usually contains name (backend path) and orig_name (download name) - const fPath = f.name || f.path || ""; - return fPath.includes(key); - }); - - if (targets.length === 0) { - console.warn("āŒ No matching files found in batch list for key:", key); - alert("Batch list does not contain this file yet. Please wait for generation to finish."); - return; - } - - // 5. Trigger download(s) - console.log(`šŸŽÆ Found ${targets.length} files to download.`); - targets.forEach((f, index) => { - setTimeout(() => { - const a = document.createElement('a'); - // Prefer url (frontend-accessible link), otherwise try data - a.href = f.url || f.data; - a.download = f.orig_name || "download"; - a.style.display = 'none'; - document.body.appendChild(a); - a.click(); - document.body.removeChild(a); - }, index * 1000); // 300ms interval to avoid browser blocking - }); -} -""" - for btn_idx in range(1, 9): - results_section[f"save_btn_{btn_idx}"].click( - fn=None, - inputs=[ - results_section[f"generated_audio_{btn_idx}"], - results_section["generated_audio_batch"], - ], - js=download_existing_js # Run the above JS + register_generation_metadata_file_handlers( + wiring_context, + auto_checkbox_inputs=auto_checkbox_inputs, + auto_checkbox_outputs=auto_checkbox_outputs, ) + register_results_save_button_handlers(wiring_context) register_results_aux_handlers( wiring_context, mode_ui_outputs=mode_ui_outputs, ) register_generation_run_handlers(wiring_context) register_generation_batch_navigation_handlers(wiring_context) - - # ========== Restore Parameters Handler ========== - results_section["restore_params_btn"].click( - fn=res_h.restore_batch_parameters, - inputs=[ - results_section["current_batch_index"], - results_section["batch_queue"] - ], - outputs=[ - generation_section["text2music_audio_code_string"], - generation_section["captions"], - generation_section["lyrics"], - generation_section["bpm"], - generation_section["key_scale"], - generation_section["time_signature"], - generation_section["vocal_language"], - generation_section["audio_duration"], - generation_section["batch_size_input"], - generation_section["inference_steps"], - generation_section["lm_temperature"], - generation_section["lm_cfg_scale"], - generation_section["lm_top_k"], - generation_section["lm_top_p"], - generation_section["think_checkbox"], - generation_section["use_cot_caption"], - generation_section["use_cot_language"], - generation_section["allow_lm_batch"], - generation_section["track_name"], - generation_section["complete_track_classes"], - generation_section["enable_normalization"], - generation_section["normalization_db"], - generation_section["latent_shift"], - generation_section["latent_rescale"], - ] - ) - - # ========== LRC Display Change Handlers ========== - # NEW APPROACH: Use lrc_display.change() to update audio subtitles - # This decouples audio value updates from subtitle updates, avoiding flickering. - # - # When lrc_display text changes (from generate, LRC button, or manual edit): - # 1. lrc_display.change() is triggered - # 2. update_audio_subtitles_from_lrc() parses LRC and updates audio subtitles - # 3. Audio value is NEVER updated here - only subtitles - for lrc_idx in range(1, 9): - results_section[f"lrc_display_{lrc_idx}"].change( - fn=res_h.update_audio_subtitles_from_lrc, - inputs=[ - results_section[f"lrc_display_{lrc_idx}"], - # audio_duration not needed - parse_lrc_to_subtitles calculates end time from timestamps - ], - outputs=[ - results_section[f"generated_audio_{lrc_idx}"], # Only updates subtitles, not value - ] - ) + register_results_restore_and_lrc_handlers(wiring_context) def setup_training_event_handlers(demo, dit_handler, llm_handler, training_section): @@ -280,429 +103,30 @@ def setup_training_event_handlers(demo, dit_handler, llm_handler, training_secti llm_handler=llm_handler, training_section=training_section, ) - training_section = training_context.training_section # ========== Load Existing Dataset (Top Section) ========== # Load existing dataset JSON at the top of Dataset Builder - training_section["load_json_btn"].click( - fn=train_h.load_existing_dataset_for_preprocess, - inputs=[ - training_section["load_json_path"], - training_section["dataset_builder_state"], - ], - outputs=[ - training_section["load_json_status"], - training_section["audio_files_table"], - training_section["sample_selector"], - training_section["dataset_builder_state"], - # Also update preview fields with first sample - training_section["preview_audio"], - training_section["preview_filename"], - training_section["edit_caption"], - training_section["edit_genre"], - training_section["prompt_override"], - training_section["edit_lyrics"], - training_section["edit_bpm"], - training_section["edit_keyscale"], - training_section["edit_timesig"], - training_section["edit_duration"], - training_section["edit_language"], - training_section["edit_instrumental"], - training_section["raw_lyrics_display"], - training_section["has_raw_lyrics_state"], - # Update dataset-level settings - training_section["dataset_name"], - training_section["custom_tag"], - training_section["tag_position"], - training_section["all_instrumental"], - training_section["genre_ratio"], - ] - ).then( - fn=lambda has_raw: gr.update(visible=has_raw), - inputs=[training_section["has_raw_lyrics_state"]], - outputs=[training_section["raw_lyrics_display"]], + register_training_dataset_load_handler( + training_context, + button_key="load_json_btn", + path_key="load_json_path", + status_key="load_json_status", ) - # ========== Dataset Builder Handlers ========== - - # Scan directory for audio files - training_section["scan_btn"].click( - fn=lambda dir, name, tag, pos, instr, state: train_h.scan_directory( - dir, name, tag, pos, instr, state - ), - inputs=[ - training_section["audio_directory"], - training_section["dataset_name"], - training_section["custom_tag"], - training_section["tag_position"], - training_section["all_instrumental"], - training_section["dataset_builder_state"], - ], - outputs=[ - training_section["audio_files_table"], - training_section["scan_status"], - training_section["sample_selector"], - training_section["dataset_builder_state"], - ] - ) - - # Auto-label all samples - training_section["auto_label_btn"].click( - fn=lambda state, skip, fmt_lyrics, trans_lyrics, only_unlab: train_h.auto_label_all( - dit_handler, llm_handler, state, skip, fmt_lyrics, trans_lyrics, only_unlab - ), - inputs=[ - training_section["dataset_builder_state"], - training_section["skip_metas"], - training_section["format_lyrics"], - training_section["transcribe_lyrics"], - training_section["only_unlabeled"], - ], - outputs=[ - training_section["audio_files_table"], - training_section["label_progress"], - training_section["dataset_builder_state"], - ] - ).then( - # Refresh preview/edit fields after labeling completes - fn=train_h.get_sample_preview, - inputs=[ - training_section["sample_selector"], - training_section["dataset_builder_state"], - ], - outputs=[ - training_section["preview_audio"], - training_section["preview_filename"], - training_section["edit_caption"], - training_section["edit_genre"], - training_section["prompt_override"], - training_section["edit_lyrics"], - training_section["edit_bpm"], - training_section["edit_keyscale"], - training_section["edit_timesig"], - training_section["edit_duration"], - training_section["edit_language"], - training_section["edit_instrumental"], - training_section["raw_lyrics_display"], - training_section["has_raw_lyrics_state"], - ] - ).then( - fn=lambda status: f"{status or 'āœ… Auto-label complete.'}\nāœ… Preview refreshed.", - inputs=[training_section["label_progress"]], - outputs=[training_section["label_progress"]], - ).then( - fn=lambda has_raw: gr.update(visible=bool(has_raw)), - inputs=[training_section["has_raw_lyrics_state"]], - outputs=[training_section["raw_lyrics_display"]], - ) - - # Mutual exclusion: format_lyrics and transcribe_lyrics cannot both be True - training_section["format_lyrics"].change( - fn=lambda fmt: gr.update(value=False) if fmt else gr.update(), - inputs=[training_section["format_lyrics"]], - outputs=[training_section["transcribe_lyrics"]] - ) - - training_section["transcribe_lyrics"].change( - fn=lambda trans: gr.update(value=False) if trans else gr.update(), - inputs=[training_section["transcribe_lyrics"]], - outputs=[training_section["format_lyrics"]] - ) - - # Sample selector change - update preview - training_section["sample_selector"].change( - fn=train_h.get_sample_preview, - inputs=[ - training_section["sample_selector"], - training_section["dataset_builder_state"], - ], - outputs=[ - training_section["preview_audio"], - training_section["preview_filename"], - training_section["edit_caption"], - training_section["edit_genre"], - training_section["prompt_override"], - training_section["edit_lyrics"], - training_section["edit_bpm"], - training_section["edit_keyscale"], - training_section["edit_timesig"], - training_section["edit_duration"], - training_section["edit_language"], - training_section["edit_instrumental"], - training_section["raw_lyrics_display"], - training_section["has_raw_lyrics_state"], - ] - ).then( - # Show/hide raw lyrics panel based on whether raw lyrics exist - fn=lambda has_raw: gr.update(visible=has_raw), - inputs=[training_section["has_raw_lyrics_state"]], - outputs=[training_section["raw_lyrics_display"]], - ) - - # Save sample edit - training_section["save_edit_btn"].click( - fn=train_h.save_sample_edit, - inputs=[ - training_section["sample_selector"], - training_section["edit_caption"], - training_section["edit_genre"], - training_section["prompt_override"], - training_section["edit_lyrics"], - training_section["edit_bpm"], - training_section["edit_keyscale"], - training_section["edit_timesig"], - training_section["edit_language"], - training_section["edit_instrumental"], - training_section["dataset_builder_state"], - ], - outputs=[ - training_section["audio_files_table"], - training_section["edit_status"], - training_section["dataset_builder_state"], - ] - ) - - # Update settings when changed (including genre_ratio) - for trigger in [training_section["custom_tag"], training_section["tag_position"], training_section["all_instrumental"], training_section["genre_ratio"]]: - trigger.change( - fn=train_h.update_settings, - inputs=[ - training_section["custom_tag"], - training_section["tag_position"], - training_section["all_instrumental"], - training_section["genre_ratio"], - training_section["dataset_builder_state"], - ], - outputs=[training_section["dataset_builder_state"]] - ) + register_training_dataset_builder_handlers(training_context) - # Save dataset - training_section["save_dataset_btn"].click( - fn=train_h.save_dataset, - inputs=[ - training_section["save_path"], - training_section["dataset_name"], - training_section["dataset_builder_state"], - ], - outputs=[ - training_section["save_status"], - training_section["save_path"], - ] - ) - # ========== Preprocess Handlers ========== # Load existing dataset JSON for preprocessing # This also updates the preview section so users can view/edit samples - training_section["load_existing_dataset_btn"].click( - fn=train_h.load_existing_dataset_for_preprocess, - inputs=[ - training_section["load_existing_dataset_path"], - training_section["dataset_builder_state"], - ], - outputs=[ - training_section["load_existing_status"], - training_section["audio_files_table"], - training_section["sample_selector"], - training_section["dataset_builder_state"], - # Also update preview fields with first sample - training_section["preview_audio"], - training_section["preview_filename"], - training_section["edit_caption"], - training_section["edit_genre"], - training_section["prompt_override"], - training_section["edit_lyrics"], - training_section["edit_bpm"], - training_section["edit_keyscale"], - training_section["edit_timesig"], - training_section["edit_duration"], - training_section["edit_language"], - training_section["edit_instrumental"], - training_section["raw_lyrics_display"], - training_section["has_raw_lyrics_state"], - # Update dataset-level settings - training_section["dataset_name"], - training_section["custom_tag"], - training_section["tag_position"], - training_section["all_instrumental"], - training_section["genre_ratio"], - ] - ).then( - fn=lambda has_raw: gr.update(visible=has_raw), - inputs=[training_section["has_raw_lyrics_state"]], - outputs=[training_section["raw_lyrics_display"]], + register_training_dataset_load_handler( + training_context, + button_key="load_existing_dataset_btn", + path_key="load_existing_dataset_path", + status_key="load_existing_status", ) # Preprocess dataset to tensor files - training_section["preprocess_btn"].click( - fn=lambda output_dir, mode, state: train_h.preprocess_dataset( - output_dir, mode, dit_handler, state - ), - inputs=[ - training_section["preprocess_output_dir"], - training_section["preprocess_mode"], - training_section["dataset_builder_state"], - ], - outputs=[training_section["preprocess_progress"]] - ) - - # ========== Training Tab Handlers ========== - - # Load preprocessed tensor dataset - training_section["load_dataset_btn"].click( - fn=train_h.load_training_dataset, - inputs=[training_section["training_tensor_dir"]], - outputs=[training_section["training_dataset_info"]] - ) - - # Start training from preprocessed tensors - def training_wrapper(tensor_dir, r, a, d, lr, ep, bs, ga, se, sh, sd, od, rc, ts): - """Stream LoRA training progress and normalize failure outputs for the UI.""" - from loguru import logger - if not isinstance(ts, dict): - ts = {"is_training": False, "should_stop": False} - try: - for progress, log_msg, plot, state in train_h.start_training( - tensor_dir, dit_handler, r, a, d, lr, ep, bs, ga, se, sh, sd, od, rc, ts - ): - yield progress, log_msg, plot, state - except Exception as e: - logger.exception("Training wrapper error") - yield f"āŒ Error: {str(e)}", str(e), None, ts - - training_section["start_training_btn"].click( - fn=training_wrapper, - inputs=[ - training_section["training_tensor_dir"], - training_section["lora_rank"], - training_section["lora_alpha"], - training_section["lora_dropout"], - training_section["learning_rate"], - training_section["train_epochs"], - training_section["train_batch_size"], - training_section["gradient_accumulation"], - training_section["save_every_n_epochs"], - training_section["training_shift"], - training_section["training_seed"], - training_section["lora_output_dir"], - training_section["resume_checkpoint_dir"], - training_section["training_state"], - ], - outputs=[ - training_section["training_progress"], - training_section["training_log"], - training_section["training_loss_plot"], - training_section["training_state"], - ] - ) - - # Stop training - training_section["stop_training_btn"].click( - fn=train_h.stop_training, - inputs=[training_section["training_state"]], - outputs=[ - training_section["training_progress"], - training_section["training_state"], - ] - ) - - # Export LoRA - training_section["export_lora_btn"].click( - fn=train_h.export_lora, - inputs=[ - training_section["export_path"], - training_section["lora_output_dir"], - ], - outputs=[training_section["export_status"]] - ) - - # ========== LoKr Training Tab Handlers ========== - - # Load preprocessed tensor dataset for LoKr - training_section["lokr_load_dataset_btn"].click( - fn=train_h.load_training_dataset, - inputs=[training_section["lokr_training_tensor_dir"]], - outputs=[training_section["lokr_training_dataset_info"]] - ) - - # Start LoKr training from preprocessed tensors - def lokr_training_wrapper( - tensor_dir, ldim, lalpha, factor, decompose_both, use_tucker, - use_scalar, weight_decompose, lr, ep, bs, ga, se, sh, sd, od, ts, - ): - """Stream LoKr training progress and normalize failure outputs for the UI.""" - from loguru import logger - if not isinstance(ts, dict): - ts = {"is_training": False, "should_stop": False} - try: - for progress, log_msg, plot, state in train_h.start_lokr_training( - tensor_dir, dit_handler, - ldim, lalpha, factor, decompose_both, use_tucker, - use_scalar, weight_decompose, - lr, ep, bs, ga, se, sh, sd, od, ts, - ): - yield progress, log_msg, plot, state - except Exception as e: - logger.exception("LoKr training wrapper error") - yield f"āŒ Error: {str(e)}", str(e), None, ts - - training_section["start_lokr_training_btn"].click( - fn=lokr_training_wrapper, - inputs=[ - training_section["lokr_training_tensor_dir"], - training_section["lokr_linear_dim"], - training_section["lokr_linear_alpha"], - training_section["lokr_factor"], - training_section["lokr_decompose_both"], - training_section["lokr_use_tucker"], - training_section["lokr_use_scalar"], - training_section["lokr_weight_decompose"], - training_section["lokr_learning_rate"], - training_section["lokr_train_epochs"], - training_section["lokr_train_batch_size"], - training_section["lokr_gradient_accumulation"], - training_section["lokr_save_every_n_epochs"], - training_section["lokr_training_shift"], - training_section["lokr_training_seed"], - training_section["lokr_output_dir"], - training_section["training_state"], - ], - outputs=[ - training_section["lokr_training_progress"], - training_section["lokr_training_log"], - training_section["lokr_training_loss_plot"], - training_section["training_state"], - ] - ) - - # Stop LoKr training (reuses same stop mechanism) - training_section["stop_lokr_training_btn"].click( - fn=train_h.stop_training, - inputs=[training_section["training_state"]], - outputs=[ - training_section["lokr_training_progress"], - training_section["training_state"], - ] - ) - - # Refresh LoKr export epochs - training_section["refresh_lokr_export_epochs_btn"].click( - fn=train_h.list_lokr_export_epochs, - inputs=[training_section["lokr_output_dir"]], - outputs=[ - training_section["lokr_export_epoch"], - training_section["lokr_export_status"], - ] - ) - - # Export LoKr - training_section["export_lokr_btn"].click( - fn=train_h.export_lokr, - inputs=[ - training_section["lokr_export_path"], - training_section["lokr_output_dir"], - training_section["lokr_export_epoch"], - ], - outputs=[training_section["lokr_export_status"]] - ) + register_training_preprocess_handler(training_context) + register_training_run_handlers(training_context) diff --git a/acestep/ui/gradio/events/wiring/__init__.py b/acestep/ui/gradio/events/wiring/__init__.py index e1695793..1953fe67 100644 --- a/acestep/ui/gradio/events/wiring/__init__.py +++ b/acestep/ui/gradio/events/wiring/__init__.py @@ -12,11 +12,22 @@ build_mode_ui_outputs, ) from .generation_metadata_wiring import register_generation_metadata_handlers +from .generation_metadata_file_wiring import register_generation_metadata_file_handlers from .generation_batch_navigation_wiring import register_generation_batch_navigation_handlers from .generation_mode_wiring import register_generation_mode_handlers from .generation_run_wiring import register_generation_run_handlers from .results_aux_wiring import register_results_aux_handlers +from .results_display_wiring import ( + register_results_restore_and_lrc_handlers, + register_results_save_button_handlers, +) from .generation_service_wiring import register_generation_service_handlers +from .training_dataset_builder_wiring import register_training_dataset_builder_handlers +from .training_dataset_preprocess_wiring import ( + register_training_dataset_load_handler, + register_training_preprocess_handler, +) +from .training_run_wiring import register_training_run_handlers __all__ = [ "GenerationWiringContext", @@ -25,9 +36,16 @@ "build_auto_checkbox_outputs", "build_mode_ui_outputs", "register_generation_batch_navigation_handlers", + "register_generation_metadata_file_handlers", "register_generation_metadata_handlers", "register_generation_mode_handlers", "register_generation_run_handlers", "register_results_aux_handlers", + "register_results_restore_and_lrc_handlers", + "register_results_save_button_handlers", "register_generation_service_handlers", + "register_training_dataset_builder_handlers", + "register_training_dataset_load_handler", + "register_training_preprocess_handler", + "register_training_run_handlers", ] diff --git a/acestep/ui/gradio/events/wiring/ast_test_utils.py b/acestep/ui/gradio/events/wiring/ast_test_utils.py new file mode 100644 index 00000000..372139d6 --- /dev/null +++ b/acestep/ui/gradio/events/wiring/ast_test_utils.py @@ -0,0 +1,18 @@ +"""Shared AST parsing helpers for wiring contract tests.""" + +import ast +from pathlib import Path + + +def load_module_ast(module_path: Path) -> ast.Module: + """Return the parsed AST module for the provided source path.""" + + return ast.parse(module_path.read_text(encoding="utf-8")) + + +def subscript_key(node: ast.Subscript) -> str | None: + """Return constant key value from a simple subscript expression.""" + + if isinstance(node.slice, ast.Constant) and isinstance(node.slice.value, str): + return node.slice.value + return None diff --git a/acestep/ui/gradio/events/wiring/decomposition_contract_test.py b/acestep/ui/gradio/events/wiring/decomposition_contract_generation_test.py similarity index 50% rename from acestep/ui/gradio/events/wiring/decomposition_contract_test.py rename to acestep/ui/gradio/events/wiring/decomposition_contract_generation_test.py index f46ae6d7..4851ce08 100644 --- a/acestep/ui/gradio/events/wiring/decomposition_contract_test.py +++ b/acestep/ui/gradio/events/wiring/decomposition_contract_generation_test.py @@ -1,102 +1,71 @@ -"""Regression tests for event wiring decomposition contracts. - -These tests validate source-level delegation in -``acestep.ui.gradio.events.__init__`` without importing Gradio dependencies. -""" +"""Generation-focused decomposition contract tests.""" import ast -from pathlib import Path import unittest - -_EVENTS_INIT_PATH = Path(__file__).resolve().parents[1] / "__init__.py" -_MODE_WIRING_PATH = Path(__file__).resolve().with_name("generation_mode_wiring.py") -_RUN_WIRING_PATH = Path(__file__).resolve().with_name("generation_run_wiring.py") -_BATCH_NAV_WIRING_PATH = Path(__file__).resolve().with_name( - "generation_batch_navigation_wiring.py" -) - - -def _load_setup_event_handlers_node() -> ast.FunctionDef: - """Return the AST node for ``setup_event_handlers``.""" - - source = _EVENTS_INIT_PATH.read_text(encoding="utf-8") - module = ast.parse(source) - for node in module.body: - if isinstance(node, ast.FunctionDef) and node.name == "setup_event_handlers": - return node - raise AssertionError("setup_event_handlers not found") - - -def _load_generation_mode_wiring_node() -> ast.FunctionDef: - """Return the AST node for ``register_generation_mode_handlers``.""" - - source = _MODE_WIRING_PATH.read_text(encoding="utf-8") - module = ast.parse(source) - for node in module.body: - if isinstance(node, ast.FunctionDef) and node.name == "register_generation_mode_handlers": - return node - raise AssertionError("register_generation_mode_handlers not found") - - -def _load_generation_run_wiring_node() -> ast.FunctionDef: - """Return the AST node for ``register_generation_run_handlers``.""" - - source = _RUN_WIRING_PATH.read_text(encoding="utf-8") - module = ast.parse(source) - for node in module.body: - if isinstance(node, ast.FunctionDef) and node.name == "register_generation_run_handlers": - return node - raise AssertionError("register_generation_run_handlers not found") - - -def _load_generation_batch_navigation_wiring_node() -> ast.FunctionDef: - """Return the AST node for ``register_generation_batch_navigation_handlers``.""" - - source = _BATCH_NAV_WIRING_PATH.read_text(encoding="utf-8") - module = ast.parse(source) - for node in module.body: - if isinstance(node, ast.FunctionDef) and node.name == "register_generation_batch_navigation_handlers": - return node - raise AssertionError("register_generation_batch_navigation_handlers not found") - - -def _call_name(node: ast.AST) -> str | None: - """Extract a simple function name from a call node target.""" - - if isinstance(node, ast.Name): - return node.id - if isinstance(node, ast.Attribute): - return node.attr - return None - - -class DecompositionContractTests(unittest.TestCase): - """Verify delegation contracts introduced in PR2/PR3/PR4/PR5/PR6 extraction.""" +try: + from .decomposition_contract_helpers import ( + call_name, + load_generation_batch_navigation_wiring_node, + load_generation_metadata_file_wiring_module, + load_generation_mode_wiring_node, + load_generation_run_wiring_node, + load_results_display_wiring_module, + load_setup_event_handlers_node, + ) +except ImportError: # pragma: no cover - supports direct file execution + from decomposition_contract_helpers import ( + call_name, + load_generation_batch_navigation_wiring_node, + load_generation_metadata_file_wiring_module, + load_generation_mode_wiring_node, + load_generation_run_wiring_node, + load_results_display_wiring_module, + load_setup_event_handlers_node, + ) + + +class DecompositionContractGenerationTests(unittest.TestCase): + """Verify generation-side delegation contracts for event wiring extraction.""" def test_setup_event_handlers_uses_generation_wiring_helpers(self): """setup_event_handlers should delegate generation wiring registration.""" - setup_node = _load_setup_event_handlers_node() + setup_node = load_setup_event_handlers_node() call_names = [] for node in ast.walk(setup_node): if isinstance(node, ast.Call): - name = _call_name(node.func) + name = call_name(node.func) if name: call_names.append(name) self.assertIn("register_generation_service_handlers", call_names) self.assertIn("register_generation_batch_navigation_handlers", call_names) + self.assertIn("register_generation_metadata_file_handlers", call_names) self.assertIn("register_generation_metadata_handlers", call_names) self.assertIn("register_generation_mode_handlers", call_names) self.assertIn("register_generation_run_handlers", call_names) self.assertIn("register_results_aux_handlers", call_names) + self.assertIn("register_results_save_button_handlers", call_names) + self.assertIn("register_results_restore_and_lrc_handlers", call_names) self.assertIn("build_mode_ui_outputs", call_names) + def test_generation_metadata_file_wiring_calls_expected_handlers(self): + """Metadata file wiring should call load-metadata and auto-uncheck handlers.""" + + wiring_node = load_generation_metadata_file_wiring_module() + attribute_names = [] + for node in ast.walk(wiring_node): + if isinstance(node, ast.Attribute): + attribute_names.append(node.attr) + + self.assertIn("load_metadata", attribute_names) + self.assertIn("uncheck_auto_for_populated_fields", attribute_names) + def test_generation_mode_wiring_uses_mode_ui_outputs_variable(self): """Mode wiring helper should bind generation_mode outputs to mode_ui_outputs.""" - wiring_node = _load_generation_mode_wiring_node() + wiring_node = load_generation_mode_wiring_node() found_mode_change_output_binding = False for node in ast.walk(wiring_node): @@ -118,12 +87,12 @@ def test_generation_mode_wiring_uses_mode_ui_outputs_variable(self): def test_generation_run_wiring_calls_expected_results_handlers(self): """Run wiring should call clear, generate stream, and background pre-generation helpers.""" - wiring_node = _load_generation_run_wiring_node() + wiring_node = load_generation_run_wiring_node() call_names = [] attribute_names = [] for node in ast.walk(wiring_node): if isinstance(node, ast.Call): - name = _call_name(node.func) + name = call_name(node.func) if name: call_names.append(name) if isinstance(node, ast.Attribute): @@ -136,12 +105,12 @@ def test_generation_run_wiring_calls_expected_results_handlers(self): def test_batch_navigation_wiring_calls_expected_results_handlers(self): """Batch navigation wiring should call previous/next/background results helpers.""" - wiring_node = _load_generation_batch_navigation_wiring_node() + wiring_node = load_generation_batch_navigation_wiring_node() call_names = [] attribute_names = [] for node in ast.walk(wiring_node): if isinstance(node, ast.Call): - name = _call_name(node.func) + name = call_name(node.func) if name: call_names.append(name) if isinstance(node, ast.Attribute): @@ -152,6 +121,18 @@ def test_batch_navigation_wiring_calls_expected_results_handlers(self): self.assertIn("navigate_to_next_batch", attribute_names) self.assertIn("generate_next_batch_background", call_names) + def test_results_display_wiring_calls_expected_results_handlers(self): + """Results display wiring should call restore and LRC subtitle handlers.""" + + wiring_node = load_results_display_wiring_module() + attribute_names = [] + for node in ast.walk(wiring_node): + if isinstance(node, ast.Attribute): + attribute_names.append(node.attr) + + self.assertIn("restore_batch_parameters", attribute_names) + self.assertIn("update_audio_subtitles_from_lrc", attribute_names) + if __name__ == "__main__": unittest.main() diff --git a/acestep/ui/gradio/events/wiring/decomposition_contract_helpers.py b/acestep/ui/gradio/events/wiring/decomposition_contract_helpers.py new file mode 100644 index 00000000..90f6c556 --- /dev/null +++ b/acestep/ui/gradio/events/wiring/decomposition_contract_helpers.py @@ -0,0 +1,133 @@ +"""Shared AST helpers for decomposition contract tests.""" + +import ast +from pathlib import Path + + +_EVENTS_INIT_PATH = Path(__file__).resolve().parents[1] / "__init__.py" +_MODE_WIRING_PATH = Path(__file__).resolve().with_name("generation_mode_wiring.py") +_METADATA_FILE_WIRING_PATH = Path(__file__).resolve().with_name( + "generation_metadata_file_wiring.py" +) +_RUN_WIRING_PATH = Path(__file__).resolve().with_name("generation_run_wiring.py") +_BATCH_NAV_WIRING_PATH = Path(__file__).resolve().with_name( + "generation_batch_navigation_wiring.py" +) +_RESULTS_DISPLAY_WIRING_PATH = Path(__file__).resolve().with_name( + "results_display_wiring.py" +) +_TRAINING_DATASET_BUILDER_WIRING_PATH = Path(__file__).resolve().with_name( + "training_dataset_builder_wiring.py" +) +_TRAINING_DATASET_PREPROCESS_WIRING_PATH = Path(__file__).resolve().with_name( + "training_dataset_preprocess_wiring.py" +) +_TRAINING_RUN_WIRING_PATH = Path(__file__).resolve().with_name("training_run_wiring.py") +_TRAINING_LOKR_WIRING_PATH = Path(__file__).resolve().with_name("training_lokr_wiring.py") + + +def load_setup_event_handlers_node() -> ast.FunctionDef: + """Return the AST node for ``setup_event_handlers``.""" + + source = _EVENTS_INIT_PATH.read_text(encoding="utf-8") + module = ast.parse(source) + for node in module.body: + if isinstance(node, ast.FunctionDef) and node.name == "setup_event_handlers": + return node + raise AssertionError("setup_event_handlers not found") + + +def load_setup_training_event_handlers_node() -> ast.FunctionDef: + """Return the AST node for ``setup_training_event_handlers``.""" + + source = _EVENTS_INIT_PATH.read_text(encoding="utf-8") + module = ast.parse(source) + for node in module.body: + if isinstance(node, ast.FunctionDef) and node.name == "setup_training_event_handlers": + return node + raise AssertionError("setup_training_event_handlers not found") + + +def load_generation_mode_wiring_node() -> ast.FunctionDef: + """Return the AST node for ``register_generation_mode_handlers``.""" + + source = _MODE_WIRING_PATH.read_text(encoding="utf-8") + module = ast.parse(source) + for node in module.body: + if isinstance(node, ast.FunctionDef) and node.name == "register_generation_mode_handlers": + return node + raise AssertionError("register_generation_mode_handlers not found") + + +def load_generation_metadata_file_wiring_module() -> ast.Module: + """Return the parsed AST module for metadata file-load wiring.""" + + source = _METADATA_FILE_WIRING_PATH.read_text(encoding="utf-8") + return ast.parse(source) + + +def load_generation_run_wiring_node() -> ast.FunctionDef: + """Return the AST node for ``register_generation_run_handlers``.""" + + source = _RUN_WIRING_PATH.read_text(encoding="utf-8") + module = ast.parse(source) + for node in module.body: + if isinstance(node, ast.FunctionDef) and node.name == "register_generation_run_handlers": + return node + raise AssertionError("register_generation_run_handlers not found") + + +def load_generation_batch_navigation_wiring_node() -> ast.FunctionDef: + """Return the AST node for ``register_generation_batch_navigation_handlers``.""" + + source = _BATCH_NAV_WIRING_PATH.read_text(encoding="utf-8") + module = ast.parse(source) + for node in module.body: + if isinstance(node, ast.FunctionDef) and node.name == "register_generation_batch_navigation_handlers": + return node + raise AssertionError("register_generation_batch_navigation_handlers not found") + + +def load_results_display_wiring_module() -> ast.Module: + """Return the parsed AST module for results display/save wiring.""" + + source = _RESULTS_DISPLAY_WIRING_PATH.read_text(encoding="utf-8") + return ast.parse(source) + + +def load_training_run_wiring_module() -> ast.Module: + """Return the parsed AST module for ``training_run_wiring.py``.""" + + source = _TRAINING_RUN_WIRING_PATH.read_text(encoding="utf-8") + return ast.parse(source) + + +def load_training_lokr_wiring_module() -> ast.Module: + """Return the parsed AST module for ``training_lokr_wiring.py``.""" + + source = _TRAINING_LOKR_WIRING_PATH.read_text(encoding="utf-8") + return ast.parse(source) + + +def load_training_dataset_preprocess_wiring_module() -> ast.Module: + """Return the parsed AST module for training dataset/preprocess wiring.""" + + source = _TRAINING_DATASET_PREPROCESS_WIRING_PATH.read_text(encoding="utf-8") + return ast.parse(source) + + +def load_training_dataset_builder_wiring_module() -> ast.Module: + """Return the parsed AST module for training dataset-builder wiring.""" + + source = _TRAINING_DATASET_BUILDER_WIRING_PATH.read_text(encoding="utf-8") + return ast.parse(source) + + +def call_name(node: ast.AST) -> str | None: + """Extract a simple function name from a call node target.""" + + if isinstance(node, ast.Name): + return node.id + if isinstance(node, ast.Attribute): + return node.attr + return None diff --git a/acestep/ui/gradio/events/wiring/decomposition_contract_training_test.py b/acestep/ui/gradio/events/wiring/decomposition_contract_training_test.py new file mode 100644 index 00000000..c4f27732 --- /dev/null +++ b/acestep/ui/gradio/events/wiring/decomposition_contract_training_test.py @@ -0,0 +1,112 @@ +"""Training-focused decomposition contract tests.""" + +import ast +import unittest + +try: + from .decomposition_contract_helpers import ( + call_name, + load_setup_training_event_handlers_node, + load_training_dataset_builder_wiring_module, + load_training_dataset_preprocess_wiring_module, + load_training_lokr_wiring_module, + load_training_run_wiring_module, + ) +except ImportError: # pragma: no cover - supports direct file execution + from decomposition_contract_helpers import ( + call_name, + load_setup_training_event_handlers_node, + load_training_dataset_builder_wiring_module, + load_training_dataset_preprocess_wiring_module, + load_training_lokr_wiring_module, + load_training_run_wiring_module, + ) + + +class DecompositionContractTrainingTests(unittest.TestCase): + """Verify training-side delegation contracts for event wiring extraction.""" + + def test_setup_training_event_handlers_uses_training_run_wiring_helper(self): + """setup_training_event_handlers should delegate run-tab wiring registration.""" + + setup_node = load_setup_training_event_handlers_node() + call_names = [] + for node in ast.walk(setup_node): + if isinstance(node, ast.Call): + name = call_name(node.func) + if name: + call_names.append(name) + + self.assertIn("register_training_run_handlers", call_names) + self.assertIn("register_training_dataset_builder_handlers", call_names) + self.assertIn("register_training_dataset_load_handler", call_names) + self.assertIn("register_training_preprocess_handler", call_names) + + def test_training_run_wiring_calls_expected_training_handlers(self): + """Training run wiring should invoke both LoRA and LoKr training entry points.""" + + training_run_node = load_training_run_wiring_module() + lokr_node = load_training_lokr_wiring_module() + + training_run_call_names = [] + for node in ast.walk(training_run_node): + if isinstance(node, ast.Call): + name = call_name(node.func) + if name: + training_run_call_names.append(name) + + lokr_call_names = [] + lokr_attribute_names = [] + for node in ast.walk(lokr_node): + if isinstance(node, ast.Call): + name = call_name(node.func) + if name: + lokr_call_names.append(name) + if isinstance(node, ast.Attribute): + lokr_attribute_names.append(node.attr) + + self.assertIn("start_training", training_run_call_names) + self.assertIn("register_lokr_training_handlers", training_run_call_names) + self.assertIn("start_lokr_training", lokr_call_names) + self.assertIn("stop_training", lokr_attribute_names) + + def test_training_dataset_builder_wiring_calls_expected_handlers(self): + """Dataset-builder wiring should call scan/label/edit/settings/save handlers.""" + + wiring_node = load_training_dataset_builder_wiring_module() + call_names = [] + attribute_names = [] + for node in ast.walk(wiring_node): + if isinstance(node, ast.Call): + name = call_name(node.func) + if name: + call_names.append(name) + if isinstance(node, ast.Attribute): + attribute_names.append(node.attr) + + self.assertIn("scan_directory", call_names) + self.assertIn("auto_label_all", call_names) + self.assertIn("save_sample_edit", attribute_names) + self.assertIn("update_settings", attribute_names) + self.assertIn("save_dataset", attribute_names) + + def test_training_dataset_preprocess_wiring_calls_expected_handlers(self): + """Dataset/preprocess wiring should call existing training handler entry points.""" + + wiring_node = load_training_dataset_preprocess_wiring_module() + call_names = [] + attribute_names = [] + for node in ast.walk(wiring_node): + if isinstance(node, ast.Call): + name = call_name(node.func) + if name: + call_names.append(name) + if isinstance(node, ast.Attribute): + attribute_names.append(node.attr) + + self.assertIn("load_existing_dataset_for_preprocess", attribute_names) + self.assertIn("preprocess_dataset", call_names) + + +if __name__ == "__main__": + unittest.main() diff --git a/acestep/ui/gradio/events/wiring/docstring_coverage_test.py b/acestep/ui/gradio/events/wiring/docstring_coverage_test.py new file mode 100644 index 00000000..e8b9b159 --- /dev/null +++ b/acestep/ui/gradio/events/wiring/docstring_coverage_test.py @@ -0,0 +1,61 @@ +"""Docstring coverage tests for decomposed event-wiring modules.""" + +import ast +from pathlib import Path +import unittest + + +_MODULE_PATHS = [ + Path(__file__).resolve().parents[1] / "__init__.py", + Path(__file__).resolve().with_name("generation_metadata_file_wiring.py"), + Path(__file__).resolve().with_name("results_display_wiring.py"), + Path(__file__).resolve().with_name("training_dataset_builder_wiring.py"), + Path(__file__).resolve().with_name("training_dataset_preprocess_wiring.py"), + Path(__file__).resolve().with_name("training_run_wiring.py"), + Path(__file__).resolve().with_name("training_lokr_wiring.py"), +] + + +def _collect_nodes_missing_docstrings(module: ast.Module) -> list[str]: + """Return qualified names for functions/classes missing docstrings.""" + + missing: list[str] = [] + + def visit(node: ast.AST, prefix: str = "") -> None: + for child in ast.iter_child_nodes(node): + if isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)): + name = f"{prefix}{child.name}" + if ast.get_docstring(child) is None: + missing.append(name) + visit(child, f"{name}.") + else: + visit(child, prefix) + + visit(module) + return missing + + +class DocstringCoverageTests(unittest.TestCase): + """Ensure decomposed event-wiring modules keep full docstring coverage.""" + + def test_module_and_symbol_docstrings_are_present(self): + """Each target module and all nested defs/classes should have docstrings.""" + + failures: list[str] = [] + for module_path in _MODULE_PATHS: + source = module_path.read_text(encoding="utf-8") + tree = ast.parse(source) + try: + module_name = str(module_path.relative_to(Path.cwd())) + except ValueError: + module_name = module_path.name + if ast.get_docstring(tree) is None: + failures.append(f"{module_name}: ") + for symbol in _collect_nodes_missing_docstrings(tree): + failures.append(f"{module_name}: {symbol}") + + self.assertEqual(failures, [], f"Missing docstrings: {failures}") + + +if __name__ == "__main__": + unittest.main() diff --git a/acestep/ui/gradio/events/wiring/generation_metadata_file_wiring.py b/acestep/ui/gradio/events/wiring/generation_metadata_file_wiring.py new file mode 100644 index 00000000..48163bb9 --- /dev/null +++ b/acestep/ui/gradio/events/wiring/generation_metadata_file_wiring.py @@ -0,0 +1,81 @@ +"""Generation metadata file-load wiring helpers.""" + +from typing import Any, Sequence + +from .. import generation_handlers as gen_h +from .context import GenerationWiringContext + + +_LOAD_METADATA_GENERATION_OUTPUT_KEYS = ( + "task_type", + "captions", + "lyrics", + "vocal_language", + "bpm", + "key_scale", + "time_signature", + "audio_duration", + "batch_size_input", + "inference_steps", + "guidance_scale", + "seed", + "random_seed_checkbox", + "use_adg", + "cfg_interval_start", + "cfg_interval_end", + "shift", + "infer_method", + "custom_timesteps", + "audio_format", + "lm_temperature", + "lm_cfg_scale", + "lm_top_k", + "lm_top_p", + "lm_negative_prompt", + "use_cot_metas", + "use_cot_caption", + "use_cot_language", + "audio_cover_strength", + "cover_noise_strength", + "think_checkbox", + "text2music_audio_code_string", + "repainting_start", + "repainting_end", + "track_name", + "complete_track_classes", + "instrumental_checkbox", +) + + +def _build_load_metadata_outputs(context: GenerationWiringContext) -> list[Any]: + """Return ordered outputs for the metadata file-load upload handler.""" + + generation_section = context.generation_section + results_section = context.results_section + outputs = [ + generation_section[key] for key in _LOAD_METADATA_GENERATION_OUTPUT_KEYS + ] + outputs.append(results_section["is_format_caption_state"]) + return outputs + + +def register_generation_metadata_file_handlers( + context: GenerationWiringContext, + *, + auto_checkbox_inputs: Sequence[Any], + auto_checkbox_outputs: Sequence[Any], +) -> None: + """Register metadata load-file upload and auto-checkbox sync handlers.""" + + generation_section = context.generation_section + llm_handler = context.llm_handler + + generation_section["load_file"].upload( + fn=lambda file_obj: gen_h.load_metadata(file_obj, llm_handler), + inputs=[generation_section["load_file"]], + outputs=_build_load_metadata_outputs(context), + ).then( + fn=gen_h.uncheck_auto_for_populated_fields, + inputs=list(auto_checkbox_inputs), + outputs=list(auto_checkbox_outputs), + ) diff --git a/acestep/ui/gradio/events/wiring/generation_metadata_file_wiring_test.py b/acestep/ui/gradio/events/wiring/generation_metadata_file_wiring_test.py new file mode 100644 index 00000000..31afe405 --- /dev/null +++ b/acestep/ui/gradio/events/wiring/generation_metadata_file_wiring_test.py @@ -0,0 +1,112 @@ +"""Unit tests for generation metadata file wiring contracts.""" + +import ast +from pathlib import Path +import unittest + +try: + from .ast_test_utils import load_module_ast +except ImportError: # pragma: no cover - supports direct file execution + from ast_test_utils import load_module_ast + + +_WIRING_PATH = Path(__file__).with_name("generation_metadata_file_wiring.py") + +_EXPECTED_METADATA_KEYS = [ + "task_type", + "captions", + "lyrics", + "vocal_language", + "bpm", + "key_scale", + "time_signature", + "audio_duration", + "batch_size_input", + "inference_steps", + "guidance_scale", + "seed", + "random_seed_checkbox", + "use_adg", + "cfg_interval_start", + "cfg_interval_end", + "shift", + "infer_method", + "custom_timesteps", + "audio_format", + "lm_temperature", + "lm_cfg_scale", + "lm_top_k", + "lm_top_p", + "lm_negative_prompt", + "use_cot_metas", + "use_cot_caption", + "use_cot_language", + "audio_cover_strength", + "cover_noise_strength", + "think_checkbox", + "text2music_audio_code_string", + "repainting_start", + "repainting_end", + "track_name", + "complete_track_classes", + "instrumental_checkbox", +] + +def _tuple_string_values(node: ast.AST) -> list[str]: + """Return string literal values from a tuple/list literal node.""" + + if not isinstance(node, (ast.Tuple, ast.List)): + raise AssertionError("Expected tuple/list node") + values = [] + for element in node.elts: + if not isinstance(element, ast.Constant) or not isinstance(element.value, str): + raise AssertionError("Expected string literal") + values.append(element.value) + return values + + +class GenerationMetadataFileWiringTests(unittest.TestCase): + """Verify metadata file-load wiring ordering and handler contracts.""" + + def test_metadata_output_key_contract_order_is_stable(self): + """The metadata output key tuple should match the expected UI ordering.""" + + module = load_module_ast(_WIRING_PATH) + for node in module.body: + if isinstance(node, ast.Assign): + for target in node.targets: + if isinstance(target, ast.Name) and target.id == "_LOAD_METADATA_GENERATION_OUTPUT_KEYS": + self.assertEqual(_tuple_string_values(node.value), _EXPECTED_METADATA_KEYS) + return + self.fail("_LOAD_METADATA_GENERATION_OUTPUT_KEYS not found") + + def test_build_outputs_appends_format_caption_state_last(self): + """build outputs helper should append is_format_caption_state at the tail.""" + + module = load_module_ast(_WIRING_PATH) + for node in module.body: + if not isinstance(node, ast.FunctionDef) or node.name != "_build_load_metadata_outputs": + continue + for inner in ast.walk(node): + if isinstance(inner, ast.Call) and isinstance(inner.func, ast.Attribute): + if inner.func.attr == "append" and inner.args: + arg = inner.args[0] + if isinstance(arg, ast.Subscript) and isinstance(arg.slice, ast.Constant): + self.assertEqual(arg.slice.value, "is_format_caption_state") + return + self.fail("append(results_section['is_format_caption_state']) not found") + + def test_register_function_references_expected_generation_handlers(self): + """Register helper should reference load-metadata and auto-uncheck handlers.""" + + module = load_module_ast(_WIRING_PATH) + attrs = [] + for node in ast.walk(module): + if isinstance(node, ast.Attribute): + attrs.append(node.attr) + self.assertIn("load_metadata", attrs) + self.assertIn("uncheck_auto_for_populated_fields", attrs) + + +if __name__ == "__main__": + unittest.main() diff --git a/acestep/ui/gradio/events/wiring/results_display_wiring.py b/acestep/ui/gradio/events/wiring/results_display_wiring.py new file mode 100644 index 00000000..60f07e76 --- /dev/null +++ b/acestep/ui/gradio/events/wiring/results_display_wiring.py @@ -0,0 +1,137 @@ +"""Results display/restore/LRC event wiring helpers.""" + +from .context import GenerationWiringContext +from .. import results_handlers as res_h + + +_DOWNLOAD_EXISTING_JS = """(current_audio, batch_files) => { + // Debug: print what the input actually is + console.log("[Debug] Current Audio Input:", current_audio); + + // 1. Safety check + if (!current_audio) { + console.warn("Warning: No audio selected or audio is empty."); + return; + } + if (!batch_files || !Array.isArray(batch_files)) { + console.warn("Warning: Batch file list is empty/not ready."); + return; + } + + // 2. Smartly extract path string + let pathString = ""; + + if (typeof current_audio === "string") { + // Case A: direct path string received + pathString = current_audio; + } else if (typeof current_audio === "object") { + // Case B: an object is received, try common properties + // Gradio file objects usually have path, url, or name + pathString = current_audio.path || current_audio.name || current_audio.url || ""; + } + + if (!pathString) { + console.error("Error: Could not extract a valid path string from input.", current_audio); + return; + } + + // 3. Extract Key (UUID) + // Path could be /tmp/.../uuid.mp3 or url like /file=.../uuid.mp3 + let filename = pathString.split(/[\\/]/).pop(); // get the filename + let key = filename.split('.')[0]; // get UUID without extension + + console.log(`Key extracted: ${key}`); + + // 4. Find matching file(s) in the list + let targets = batch_files.filter(f => { + // Also extract names from batch_files objects + // f usually contains name (backend path) and orig_name (download name) + const fPath = f.name || f.path || ""; + return fPath.includes(key); + }); + + if (targets.length === 0) { + console.warn("Warning: No matching files found in batch list for key:", key); + alert("Batch list does not contain this file yet. Please wait for generation to finish."); + return; + } + + // 5. Trigger download(s) + console.log(`Found ${targets.length} files to download.`); + targets.forEach((f, index) => { + setTimeout(() => { + const a = document.createElement('a'); + // Prefer url (frontend-accessible link), otherwise try data + a.href = f.url || f.data; + a.download = f.orig_name || "download"; + a.style.display = 'none'; + document.body.appendChild(a); + a.click(); + document.body.removeChild(a); + }, index * 1000); // 1000ms interval per index to avoid browser blocking + }); +} +""" + + +def register_results_save_button_handlers(context: GenerationWiringContext) -> None: + """Register save/download button JS handlers for the 8 result slots.""" + + results_section = context.results_section + for btn_idx in range(1, 9): + results_section[f"save_btn_{btn_idx}"].click( + fn=None, + inputs=[ + results_section[f"generated_audio_{btn_idx}"], + results_section["generated_audio_batch"], + ], + js=_DOWNLOAD_EXISTING_JS, + ) + + +def register_results_restore_and_lrc_handlers(context: GenerationWiringContext) -> None: + """Register restore-parameters and LRC subtitle-sync handlers.""" + + generation_section = context.generation_section + results_section = context.results_section + + results_section["restore_params_btn"].click( + fn=res_h.restore_batch_parameters, + inputs=[ + results_section["current_batch_index"], + results_section["batch_queue"], + ], + outputs=[ + generation_section["text2music_audio_code_string"], + generation_section["captions"], + generation_section["lyrics"], + generation_section["bpm"], + generation_section["key_scale"], + generation_section["time_signature"], + generation_section["vocal_language"], + generation_section["audio_duration"], + generation_section["batch_size_input"], + generation_section["inference_steps"], + generation_section["lm_temperature"], + generation_section["lm_cfg_scale"], + generation_section["lm_top_k"], + generation_section["lm_top_p"], + generation_section["think_checkbox"], + generation_section["use_cot_caption"], + generation_section["use_cot_language"], + generation_section["allow_lm_batch"], + generation_section["track_name"], + generation_section["complete_track_classes"], + generation_section["enable_normalization"], + generation_section["normalization_db"], + generation_section["latent_shift"], + generation_section["latent_rescale"], + ], + ) + + for lrc_idx in range(1, 9): + results_section[f"lrc_display_{lrc_idx}"].change( + fn=res_h.update_audio_subtitles_from_lrc, + inputs=[results_section[f"lrc_display_{lrc_idx}"]], + outputs=[results_section[f"generated_audio_{lrc_idx}"]], + ) diff --git a/acestep/ui/gradio/events/wiring/results_display_wiring_test.py b/acestep/ui/gradio/events/wiring/results_display_wiring_test.py new file mode 100644 index 00000000..5c46b55f --- /dev/null +++ b/acestep/ui/gradio/events/wiring/results_display_wiring_test.py @@ -0,0 +1,123 @@ +"""Unit tests for results display wiring contracts.""" + +import ast +from pathlib import Path +import unittest + +try: + from .ast_test_utils import load_module_ast, subscript_key +except ImportError: # pragma: no cover - supports direct file execution + from ast_test_utils import load_module_ast, subscript_key + + +_WIRING_PATH = Path(__file__).with_name("results_display_wiring.py") + +_EXPECTED_RESTORE_OUTPUT_KEYS = [ + "text2music_audio_code_string", + "captions", + "lyrics", + "bpm", + "key_scale", + "time_signature", + "vocal_language", + "audio_duration", + "batch_size_input", + "inference_steps", + "lm_temperature", + "lm_cfg_scale", + "lm_top_k", + "lm_top_p", + "think_checkbox", + "use_cot_caption", + "use_cot_language", + "allow_lm_batch", + "track_name", + "complete_track_classes", + "enable_normalization", + "normalization_db", + "latent_shift", + "latent_rescale", +] + +_EXPECTED_JS_MARKERS = [ + "[Debug] Current Audio Input:", + "Warning: No audio selected or audio is empty.", + "Warning: Batch file list is empty/not ready.", + "Error: Could not extract a valid path string from input.", + "Key extracted:", + "Warning: No matching files found in batch list for key:", + "Found ${targets.length} files to download.", +] + +_FORBIDDEN_MOJIBAKE_MARKERS = ["ƃ", "ðŸ", "âŔ", "Ć¢Ā"] + +class ResultsDisplayWiringTests(unittest.TestCase): + """Verify save/download JS and restore/LRC wiring ordering contracts.""" + + def test_download_js_contains_expected_ascii_messages(self): + """Download JS should contain expected ASCII diagnostics and no mojibake markers.""" + + module = load_module_ast(_WIRING_PATH) + js_literal = None + for node in module.body: + if isinstance(node, ast.Assign): + for target in node.targets: + if isinstance(target, ast.Name) and target.id == "_DOWNLOAD_EXISTING_JS": + if isinstance(node.value, ast.Constant) and isinstance(node.value.value, str): + js_literal = node.value.value + break + if js_literal is not None: + break + self.assertIsNotNone(js_literal, "_DOWNLOAD_EXISTING_JS not found") + for marker in _EXPECTED_JS_MARKERS: + self.assertIn(marker, js_literal) + for marker in _FORBIDDEN_MOJIBAKE_MARKERS: + self.assertNotIn(marker, js_literal) + + def test_restore_outputs_keep_expected_order(self): + """Restore params click outputs should keep existing generation field ordering.""" + + module = load_module_ast(_WIRING_PATH) + for node in module.body: + if not isinstance(node, ast.FunctionDef) or node.name != "register_results_restore_and_lrc_handlers": + continue + for call in ast.walk(node): + if not isinstance(call, ast.Call): + continue + if not isinstance(call.func, ast.Attribute) or call.func.attr != "click": + continue + for keyword in call.keywords: + if keyword.arg != "outputs" or not isinstance(keyword.value, ast.List): + continue + keys = [] + for element in keyword.value.elts: + if isinstance(element, ast.Subscript): + key = subscript_key(element) + if key is not None: + keys.append(key) + if keys == _EXPECTED_RESTORE_OUTPUT_KEYS: + return + self.fail("restore_params_btn outputs contract not found") + + def test_save_and_lrc_handlers_cover_all_8_result_slots(self): + """Both save-btn and lrc-display loops should iterate over slots 1..8.""" + + module = load_module_ast(_WIRING_PATH) + range_calls = [] + for node in ast.walk(module): + if isinstance(node, ast.For) and isinstance(node.iter, ast.Call): + call = node.iter + if isinstance(call.func, ast.Name) and call.func.id == "range": + args = [] + for arg in call.args: + if isinstance(arg, ast.Constant) and isinstance(arg.value, int): + args.append(arg.value) + if args: + range_calls.append(tuple(args)) + + self.assertIn((1, 9), range_calls) + self.assertGreaterEqual(range_calls.count((1, 9)), 2) + + +if __name__ == "__main__": + unittest.main() diff --git a/acestep/ui/gradio/events/wiring/training_dataset_builder_wiring.py b/acestep/ui/gradio/events/wiring/training_dataset_builder_wiring.py new file mode 100644 index 00000000..d3b868fc --- /dev/null +++ b/acestep/ui/gradio/events/wiring/training_dataset_builder_wiring.py @@ -0,0 +1,176 @@ +"""Training dataset-builder event wiring helpers.""" + +from typing import Any, Mapping + +import gradio as gr + +from .. import training_handlers as train_h +from .context import TrainingWiringContext + + +_SAMPLE_PREVIEW_OUTPUT_KEYS = ( + "preview_audio", + "preview_filename", + "edit_caption", + "edit_genre", + "prompt_override", + "edit_lyrics", + "edit_bpm", + "edit_keyscale", + "edit_timesig", + "edit_duration", + "edit_language", + "edit_instrumental", + "raw_lyrics_display", + "has_raw_lyrics_state", +) + +_SETTINGS_TRIGGER_KEYS = ( + "custom_tag", + "tag_position", + "all_instrumental", + "genre_ratio", +) + +_CHECKMARK = "\u2705" + + +def _build_sample_preview_outputs(training_section: Mapping[str, Any]) -> list[Any]: + """Return ordered sample-preview outputs shared by preview refresh handlers.""" + + return [training_section[key] for key in _SAMPLE_PREVIEW_OUTPUT_KEYS] + + +def register_training_dataset_builder_handlers(context: TrainingWiringContext) -> None: + """Register dataset-builder handlers while preserving existing IO ordering.""" + + training_section = context.training_section + dit_handler = context.dit_handler + llm_handler = context.llm_handler + sample_preview_outputs = _build_sample_preview_outputs(training_section) + + training_section["scan_btn"].click( + fn=lambda directory, name, tag, pos, instr, state: train_h.scan_directory( + directory, name, tag, pos, instr, state + ), + inputs=[ + training_section["audio_directory"], + training_section["dataset_name"], + training_section["custom_tag"], + training_section["tag_position"], + training_section["all_instrumental"], + training_section["dataset_builder_state"], + ], + outputs=[ + training_section["audio_files_table"], + training_section["scan_status"], + training_section["sample_selector"], + training_section["dataset_builder_state"], + ], + ) + + training_section["auto_label_btn"].click( + fn=lambda state, skip, fmt_lyrics, trans_lyrics, only_unlab: train_h.auto_label_all( + dit_handler, llm_handler, state, skip, fmt_lyrics, trans_lyrics, only_unlab + ), + inputs=[ + training_section["dataset_builder_state"], + training_section["skip_metas"], + training_section["format_lyrics"], + training_section["transcribe_lyrics"], + training_section["only_unlabeled"], + ], + outputs=[ + training_section["audio_files_table"], + training_section["label_progress"], + training_section["dataset_builder_state"], + ], + ).then( + fn=train_h.get_sample_preview, + inputs=[ + training_section["sample_selector"], + training_section["dataset_builder_state"], + ], + outputs=sample_preview_outputs, + ).then( + fn=lambda status: f"{status or (_CHECKMARK + ' Auto-label complete.')}\n{_CHECKMARK} Preview refreshed.", + inputs=[training_section["label_progress"]], + outputs=[training_section["label_progress"]], + ).then( + fn=lambda has_raw: gr.update(visible=bool(has_raw)), + inputs=[training_section["has_raw_lyrics_state"]], + outputs=[training_section["raw_lyrics_display"]], + ) + + training_section["format_lyrics"].change( + fn=lambda fmt: gr.update(value=False) if fmt else gr.update(), + inputs=[training_section["format_lyrics"]], + outputs=[training_section["transcribe_lyrics"]], + ) + + training_section["transcribe_lyrics"].change( + fn=lambda trans: gr.update(value=False) if trans else gr.update(), + inputs=[training_section["transcribe_lyrics"]], + outputs=[training_section["format_lyrics"]], + ) + + training_section["sample_selector"].change( + fn=train_h.get_sample_preview, + inputs=[ + training_section["sample_selector"], + training_section["dataset_builder_state"], + ], + outputs=sample_preview_outputs, + ).then( + fn=lambda has_raw: gr.update(visible=has_raw), + inputs=[training_section["has_raw_lyrics_state"]], + outputs=[training_section["raw_lyrics_display"]], + ) + + training_section["save_edit_btn"].click( + fn=train_h.save_sample_edit, + inputs=[ + training_section["sample_selector"], + training_section["edit_caption"], + training_section["edit_genre"], + training_section["prompt_override"], + training_section["edit_lyrics"], + training_section["edit_bpm"], + training_section["edit_keyscale"], + training_section["edit_timesig"], + training_section["edit_language"], + training_section["edit_instrumental"], + training_section["dataset_builder_state"], + ], + outputs=[ + training_section["audio_files_table"], + training_section["edit_status"], + training_section["dataset_builder_state"], + ], + ) + + for trigger_key in _SETTINGS_TRIGGER_KEYS: + training_section[trigger_key].change( + fn=train_h.update_settings, + inputs=[ + training_section["custom_tag"], + training_section["tag_position"], + training_section["all_instrumental"], + training_section["genre_ratio"], + training_section["dataset_builder_state"], + ], + outputs=[training_section["dataset_builder_state"]], + ) + + training_section["save_dataset_btn"].click( + fn=train_h.save_dataset, + inputs=[ + training_section["save_path"], + training_section["dataset_name"], + training_section["dataset_builder_state"], + ], + outputs=[ + training_section["save_status"], + training_section["save_path"], + ], + ) diff --git a/acestep/ui/gradio/events/wiring/training_dataset_preprocess_wiring.py b/acestep/ui/gradio/events/wiring/training_dataset_preprocess_wiring.py new file mode 100644 index 00000000..4dbee4ab --- /dev/null +++ b/acestep/ui/gradio/events/wiring/training_dataset_preprocess_wiring.py @@ -0,0 +1,87 @@ +"""Training dataset-load and preprocess wiring helpers.""" + +from typing import Any, Mapping + +import gradio as gr + +from .. import training_handlers as train_h +from .context import TrainingWiringContext + + +_DATASET_LOAD_SHARED_OUTPUT_KEYS = ( + "audio_files_table", + "sample_selector", + "dataset_builder_state", + "preview_audio", + "preview_filename", + "edit_caption", + "edit_genre", + "prompt_override", + "edit_lyrics", + "edit_bpm", + "edit_keyscale", + "edit_timesig", + "edit_duration", + "edit_language", + "edit_instrumental", + "raw_lyrics_display", + "has_raw_lyrics_state", + "dataset_name", + "custom_tag", + "tag_position", + "all_instrumental", + "genre_ratio", +) + + +def _build_dataset_load_outputs( + training_section: Mapping[str, Any], + status_key: str, +) -> list[Any]: + """Return the ordered output list for dataset-load button wiring.""" + + return [training_section[status_key]] + [ + training_section[key] for key in _DATASET_LOAD_SHARED_OUTPUT_KEYS + ] + + +def register_training_dataset_load_handler( + context: TrainingWiringContext, + *, + button_key: str, + path_key: str, + status_key: str, +) -> None: + """Register one dataset JSON load button with shared output/update contracts.""" + + training_section = context.training_section + training_section[button_key].click( + fn=train_h.load_existing_dataset_for_preprocess, + inputs=[ + training_section[path_key], + training_section["dataset_builder_state"], + ], + outputs=_build_dataset_load_outputs(training_section, status_key), + ).then( + fn=lambda has_raw: gr.update(visible=has_raw), + inputs=[training_section["has_raw_lyrics_state"]], + outputs=[training_section["raw_lyrics_display"]], + ) + + +def register_training_preprocess_handler(context: TrainingWiringContext) -> None: + """Register preprocess button wiring for tensor conversion.""" + + training_section = context.training_section + dit_handler = context.dit_handler + training_section["preprocess_btn"].click( + fn=lambda output_dir, mode, state: train_h.preprocess_dataset( + output_dir, mode, dit_handler, state + ), + inputs=[ + training_section["preprocess_output_dir"], + training_section["preprocess_mode"], + training_section["dataset_builder_state"], + ], + outputs=[training_section["preprocess_progress"]], + ) diff --git a/acestep/ui/gradio/events/wiring/training_lokr_wiring.py b/acestep/ui/gradio/events/wiring/training_lokr_wiring.py new file mode 100644 index 00000000..30786936 --- /dev/null +++ b/acestep/ui/gradio/events/wiring/training_lokr_wiring.py @@ -0,0 +1,143 @@ +"""LoKr-specific training run wiring helpers.""" + +from typing import Any, Callable, Iterator + +from loguru import logger + +from .. import training_handlers as train_h +from .context import TrainingWiringContext + + +def _build_lokr_training_wrapper( + dit_handler: Any, + normalize_training_state: Callable[[Any], dict[str, bool]], +): + """Build the LoKr training stream wrapper bound to the current DiT handler.""" + + def lokr_training_wrapper( + tensor_dir: Any, + lokr_linear_dim: Any, + lokr_linear_alpha: Any, + lokr_factor: Any, + lokr_decompose_both: Any, + lokr_use_tucker: Any, + lokr_use_scalar: Any, + lokr_weight_decompose: Any, + lokr_learning_rate: Any, + lokr_train_epochs: Any, + lokr_train_batch_size: Any, + lokr_gradient_accumulation: Any, + lokr_save_every_n_epochs: Any, + lokr_training_shift: Any, + lokr_training_seed: Any, + lokr_output_dir: Any, + training_state: Any, + ) -> Iterator[tuple[Any, Any, Any, dict[str, bool]]]: + """Stream LoKr training progress and normalize failure outputs for UI.""" + + state = normalize_training_state(training_state) + try: + for progress, log_msg, plot, next_state in train_h.start_lokr_training( + tensor_dir, + dit_handler, + lokr_linear_dim, + lokr_linear_alpha, + lokr_factor, + lokr_decompose_both, + lokr_use_tucker, + lokr_use_scalar, + lokr_weight_decompose, + lokr_learning_rate, + lokr_train_epochs, + lokr_train_batch_size, + lokr_gradient_accumulation, + lokr_save_every_n_epochs, + lokr_training_shift, + lokr_training_seed, + lokr_output_dir, + state, + ): + yield progress, log_msg, plot, next_state + except Exception as exc: # pragma: no cover - defensive UI wrapper + logger.exception("LoKr training wrapper error") + yield f"\u274c Error: {exc!r}", f"{exc!r}", None, state + + return lokr_training_wrapper + + +def register_lokr_training_handlers( + context: TrainingWiringContext, + *, + normalize_training_state: Callable[[Any], dict[str, bool]], +) -> None: + """Register LoKr training handlers with stable IO ordering.""" + + training_section = context.training_section + lokr_training_wrapper = _build_lokr_training_wrapper( + context.dit_handler, + normalize_training_state, + ) + + # ========== LoKr Training Tab Handlers ========== + training_section["lokr_load_dataset_btn"].click( + fn=train_h.load_training_dataset, + inputs=[training_section["lokr_training_tensor_dir"]], + outputs=[training_section["lokr_training_dataset_info"]], + ) + + training_section["start_lokr_training_btn"].click( + fn=lokr_training_wrapper, + inputs=[ + training_section["lokr_training_tensor_dir"], + training_section["lokr_linear_dim"], + training_section["lokr_linear_alpha"], + training_section["lokr_factor"], + training_section["lokr_decompose_both"], + training_section["lokr_use_tucker"], + training_section["lokr_use_scalar"], + training_section["lokr_weight_decompose"], + training_section["lokr_learning_rate"], + training_section["lokr_train_epochs"], + training_section["lokr_train_batch_size"], + training_section["lokr_gradient_accumulation"], + training_section["lokr_save_every_n_epochs"], + training_section["lokr_training_shift"], + training_section["lokr_training_seed"], + training_section["lokr_output_dir"], + training_section["training_state"], + ], + outputs=[ + training_section["lokr_training_progress"], + training_section["lokr_training_log"], + training_section["lokr_training_loss_plot"], + training_section["training_state"], + ], + ) + + training_section["stop_lokr_training_btn"].click( + fn=train_h.stop_training, + inputs=[training_section["training_state"]], + outputs=[ + training_section["lokr_training_progress"], + training_section["training_state"], + ], + ) + + training_section["refresh_lokr_export_epochs_btn"].click( + fn=train_h.list_lokr_export_epochs, + inputs=[training_section["lokr_output_dir"]], + outputs=[ + training_section["lokr_export_epoch"], + training_section["lokr_export_status"], + ], + ) + + training_section["export_lokr_btn"].click( + fn=train_h.export_lokr, + inputs=[ + training_section["lokr_export_path"], + training_section["lokr_output_dir"], + training_section["lokr_export_epoch"], + ], + outputs=[training_section["lokr_export_status"]], + ) diff --git a/acestep/ui/gradio/events/wiring/training_run_wiring.py b/acestep/ui/gradio/events/wiring/training_run_wiring.py new file mode 100644 index 00000000..bba36333 --- /dev/null +++ b/acestep/ui/gradio/events/wiring/training_run_wiring.py @@ -0,0 +1,128 @@ +"""Training run wiring helpers extracted from ``events.__init__``.""" + +from typing import Any, Iterator + +from loguru import logger + +from .. import training_handlers as train_h +from .context import TrainingWiringContext +from .training_lokr_wiring import register_lokr_training_handlers + + +def _normalize_training_state(training_state: Any) -> dict[str, bool]: + """Return a valid mutable training-state mapping for streaming wrappers.""" + + if isinstance(training_state, dict): + return training_state + return {"is_training": False, "should_stop": False} + + +def _build_training_wrapper(dit_handler: Any): + """Build the training stream wrapper bound to the current DiT handler.""" + + def training_wrapper( + tensor_dir: Any, + lora_rank: Any, + lora_alpha: Any, + lora_dropout: Any, + learning_rate: Any, + train_epochs: Any, + train_batch_size: Any, + gradient_accumulation: Any, + save_every_n_epochs: Any, + training_shift: Any, + training_seed: Any, + lora_output_dir: Any, + resume_checkpoint_dir: Any, + training_state: Any, + ) -> Iterator[tuple[Any, Any, Any, dict[str, bool]]]: + """Stream LoRA training progress and normalize failure outputs for UI.""" + + state = _normalize_training_state(training_state) + try: + for progress, log_msg, plot, next_state in train_h.start_training( + tensor_dir, + dit_handler, + lora_rank, + lora_alpha, + lora_dropout, + learning_rate, + train_epochs, + train_batch_size, + gradient_accumulation, + save_every_n_epochs, + training_shift, + training_seed, + lora_output_dir, + resume_checkpoint_dir, + state, + ): + yield progress, log_msg, plot, next_state + except Exception as exc: # pragma: no cover - defensive UI wrapper + logger.exception("Training wrapper error") + yield f"\u274c Error: {exc!s}", f"{exc!s}", None, state + + return training_wrapper + + +def register_training_run_handlers(context: TrainingWiringContext) -> None: + """Register training run-tab handlers with stable IO ordering.""" + + training_section = context.training_section + training_wrapper = _build_training_wrapper(context.dit_handler) + + # ========== Training Tab Handlers ========== + training_section["load_dataset_btn"].click( + fn=train_h.load_training_dataset, + inputs=[training_section["training_tensor_dir"]], + outputs=[training_section["training_dataset_info"]], + ) + + training_section["start_training_btn"].click( + fn=training_wrapper, + inputs=[ + training_section["training_tensor_dir"], + training_section["lora_rank"], + training_section["lora_alpha"], + training_section["lora_dropout"], + training_section["learning_rate"], + training_section["train_epochs"], + training_section["train_batch_size"], + training_section["gradient_accumulation"], + training_section["save_every_n_epochs"], + training_section["training_shift"], + training_section["training_seed"], + training_section["lora_output_dir"], + training_section["resume_checkpoint_dir"], + training_section["training_state"], + ], + outputs=[ + training_section["training_progress"], + training_section["training_log"], + training_section["training_loss_plot"], + training_section["training_state"], + ], + ) + + training_section["stop_training_btn"].click( + fn=train_h.stop_training, + inputs=[training_section["training_state"]], + outputs=[ + training_section["training_progress"], + training_section["training_state"], + ], + ) + + training_section["export_lora_btn"].click( + fn=train_h.export_lora, + inputs=[ + training_section["export_path"], + training_section["lora_output_dir"], + ], + outputs=[training_section["export_status"]], + ) + + register_lokr_training_handlers( + context, + normalize_training_state=_normalize_training_state, + )