Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 74 additions & 25 deletions acestep/core/generation/handler/lora/controls.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,31 @@
from acestep.debug_utils import debug_log


def _toggle_lokr(decoder, enable: bool, scale: float = 1.0) -> bool:
    """Toggle a LyCORIS LoKr adapter via its multiplier.

    Disabling is modeled as setting the multiplier to 0.0 rather than
    detaching the net, so re-enabling is a cheap multiplier restore.

    Args:
        decoder: Model decoder that may carry a ``_lycoris_net`` attribute.
        enable: ``True`` to activate (restore multiplier), ``False`` to zero it.
        scale: Multiplier value when enabling (default 1.0).

    Returns:
        ``True`` if a LyCORIS net was found and toggled, ``False`` otherwise.
    """
    lycoris_net = getattr(decoder, "_lycoris_net", None)
    if lycoris_net is None:
        return False
    # Probe for the hook defensively: an unexpected net object should make
    # this a no-op (reported via the return value), not a crash.
    set_mul = getattr(lycoris_net, "set_multiplier", None)
    if not callable(set_mul):
        return False
    target = float(scale) if enable else 0.0
    set_mul(target)
    # Lazy %-style args defer string formatting until the record is emitted.
    logger.info("LoKr multiplier set to %s", target)
    return True


def set_use_lora(self, use_lora: bool) -> str:
"""Toggle LoRA usage for inference."""
"""Toggle LoRA/LoKr usage for inference."""
if use_lora and not self.lora_loaded:
return "❌ No LoRA adapter loaded. Please load a LoRA first."

Expand All @@ -20,28 +43,41 @@ def set_use_lora(self, use_lora: bool) -> str:
if self.lora_loaded and decoder is None:
logger.warning("LoRA is marked as loaded, but model/decoder is unavailable during toggle.")

if self.lora_loaded and decoder is not None and hasattr(decoder, "disable_adapter_layers"):
try:
if use_lora:
active = getattr(self, "_lora_active_adapter", None)
if active and hasattr(decoder, "set_adapter"):
try:
decoder.set_adapter(active)
except Exception:
pass
decoder.enable_adapter_layers()
logger.info("LoRA adapter enabled")
scale = getattr(self, "_active_loras", {}).get(active, 1.0)
if active and scale != 1.0:
self.set_lora_scale(active, scale)
else:
decoder.disable_adapter_layers()
logger.info("LoRA adapter disabled")
except Exception as e:
logger.warning(f"Could not toggle adapter layers: {e}")

if self.lora_loaded and decoder is not None:
adapter_type = getattr(self, "_adapter_type", None)

# LoKr (LyCORIS) path: toggle via set_multiplier
if adapter_type == "lokr":
active = getattr(self, "_lora_active_adapter", None)
scale = getattr(self, "_active_loras", {}).get(active, 1.0) if active else self.lora_scale
toggled = _toggle_lokr(decoder, use_lora, scale=scale)
if not toggled:
logger.warning("LoKr adapter type set but no _lycoris_net found on decoder")

# PEFT LoRA path: toggle via enable/disable adapter layers
elif hasattr(decoder, "disable_adapter_layers"):
try:
if use_lora:
active = getattr(self, "_lora_active_adapter", None)
if active and hasattr(decoder, "set_adapter"):
try:
decoder.set_adapter(active)
except Exception:
pass
decoder.enable_adapter_layers()
logger.info("LoRA adapter enabled")
scale = getattr(self, "_active_loras", {}).get(active, 1.0)
if active and scale != 1.0:
self.set_lora_scale(active, scale)
else:
decoder.disable_adapter_layers()
logger.info("LoRA adapter disabled")
except Exception as e:
logger.warning(f"Could not toggle adapter layers: {e}")

adapter_label = "LoKr" if getattr(self, "_adapter_type", None) == "lokr" else "LoRA"
status = "enabled" if use_lora else "disabled"
return f"✅ LoRA {status}"
return f"✅ {adapter_label} {status}"


def set_lora_scale(self, adapter_name_or_scale: str | float, scale: float | None = None) -> str:
Expand Down Expand Up @@ -76,10 +112,23 @@ def set_lora_scale(self, adapter_name_or_scale: str | float, scale: float | None
self._active_loras[effective_name] = scale_value
self.lora_scale = scale_value # backward compat: single "current" scale for status/UI

if not self.use_lora:
logger.info(f"LoRA scale for '{effective_name}' set to {scale_value:.2f} (will apply when LoRA is enabled)")
return f"✅ LoRA scale ({effective_name}): {scale_value:.2f} (LoRA disabled)"
adapter_label = "LoKr" if getattr(self, "_adapter_type", None) == "lokr" else "LoRA"

if not self.use_lora:
logger.info(f"{adapter_label} scale for '{effective_name}' set to {scale_value:.2f} (will apply when enabled)")
return f"✅ {adapter_label} scale ({effective_name}): {scale_value:.2f} ({adapter_label} disabled)"

# LoKr (LyCORIS) path: apply scale via set_multiplier
if getattr(self, "_adapter_type", None) == "lokr":
decoder = getattr(getattr(self, "model", None), "decoder", None)
if decoder is not None:
toggled = _toggle_lokr(decoder, True, scale=scale_value)
if toggled:
return f"✅ {adapter_label} scale ({effective_name}): {scale_value:.2f}"
logger.warning("LoKr adapter type set but no _lycoris_net found for scale")
return f"⚠️ {adapter_label} scale set to {scale_value:.2f} (no LyCORIS net found)"

# PEFT LoRA path: apply scale via registry
try:
rebuilt_adapters: list[str] | None = None
if not getattr(self, "_lora_adapter_registry", None):
Expand Down
197 changes: 197 additions & 0 deletions acestep/core/generation/handler/lora/controls_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
"""Tests for LoRA/LoKr runtime controls (toggle, scale)."""

import unittest
from types import SimpleNamespace
from unittest.mock import Mock

from acestep.core.generation.handler.lora.controls import (
_toggle_lokr,
set_lora_scale,
set_use_lora,
)


class _DummyHandler:
"""Handler stub exposing the attributes used by ``set_use_lora`` / ``set_lora_scale``."""

def __init__(self, adapter_type=None) -> None:
self.model = SimpleNamespace(decoder=SimpleNamespace())
self.lora_loaded = True
self.use_lora = True
self.lora_scale = 1.0
self._adapter_type = adapter_type
self._lora_active_adapter = "default"
self._active_loras = {"default": 1.0}
self._lora_service = SimpleNamespace(
registry={"default": {}},
scale_state={},
active_adapter="default",
last_scale_report={},
synthetic_default_mode=False,
)

def _ensure_lora_registry(self):
return None

def _rebuild_lora_registry(self, lora_path=None):
return 0, list(self._active_loras.keys())

def _sync_lora_state_from_service(self):
return None

def _apply_scale_to_adapter(self, name, scale):
return 1

def set_lora_scale(self, adapter_name_or_scale, scale=None):
return set_lora_scale(self, adapter_name_or_scale, scale)


class ToggleLokrTests(unittest.TestCase):
    """Unit tests for the ``_toggle_lokr`` helper."""

    def test_disable_sets_multiplier_to_zero(self):
        """Disabling should call set_multiplier(0.0)."""
        net = SimpleNamespace(set_multiplier=Mock())
        outcome = _toggle_lokr(SimpleNamespace(_lycoris_net=net), enable=False)
        net.set_multiplier.assert_called_once_with(0.0)
        self.assertTrue(outcome)

    def test_enable_sets_multiplier_to_scale(self):
        """Enabling should forward the requested scale to set_multiplier."""
        net = SimpleNamespace(set_multiplier=Mock())
        outcome = _toggle_lokr(
            SimpleNamespace(_lycoris_net=net), enable=True, scale=0.75
        )
        net.set_multiplier.assert_called_once_with(0.75)
        self.assertTrue(outcome)

    def test_returns_false_when_no_lycoris_net(self):
        """A decoder without _lycoris_net is reported as not toggled."""
        self.assertFalse(_toggle_lokr(SimpleNamespace(), enable=False))

    def test_returns_false_when_no_set_multiplier(self):
        """A net lacking set_multiplier is reported as not toggled."""
        bare = SimpleNamespace(_lycoris_net=SimpleNamespace())
        self.assertFalse(_toggle_lokr(bare, enable=True))


class SetUseLokrTests(unittest.TestCase):
    """Tests for set_use_lora with the LoKr adapter type."""

    def _handler_with_net(self, **overrides):
        """Build a LoKr handler whose decoder carries a mock LyCORIS net."""
        handler = _DummyHandler(adapter_type="lokr")
        net = SimpleNamespace(set_multiplier=Mock())
        handler.model.decoder._lycoris_net = net
        for attr, value in overrides.items():
            setattr(handler, attr, value)
        return handler, net

    def test_disable_lokr_zeros_multiplier(self):
        """Unchecking use_lora should set LoKr multiplier to 0."""
        handler, net = self._handler_with_net()

        message = set_use_lora(handler, False)

        net.set_multiplier.assert_called_once_with(0.0)
        self.assertFalse(handler.use_lora)
        self.assertIn("LoKr", message)
        self.assertIn("disabled", message)

    def test_enable_lokr_restores_multiplier(self):
        """Re-checking use_lora should restore LoKr multiplier to saved scale."""
        handler, net = self._handler_with_net(
            use_lora=False, _active_loras={"default": 0.8}
        )

        message = set_use_lora(handler, True)

        net.set_multiplier.assert_called_once_with(0.8)
        self.assertTrue(handler.use_lora)
        self.assertIn("LoKr", message)
        self.assertIn("enabled", message)

    def test_enable_lokr_uses_lora_scale_fallback(self):
        """When no active adapter, should fall back to self.lora_scale."""
        handler, net = self._handler_with_net(
            use_lora=False, _lora_active_adapter=None, lora_scale=0.5
        )

        set_use_lora(handler, True)

        net.set_multiplier.assert_called_once_with(0.5)


class SetUsePeftLoraTests(unittest.TestCase):
    """Tests for set_use_lora with the PEFT LoRA adapter type (non-regression)."""

    def test_disable_peft_lora_calls_disable_adapter_layers(self):
        """Unchecking use_lora should call disable_adapter_layers for PEFT."""
        handler = _DummyHandler(adapter_type="lora")
        disable = Mock()
        handler.model.decoder.disable_adapter_layers = disable

        message = set_use_lora(handler, False)

        disable.assert_called_once()
        self.assertFalse(handler.use_lora)
        self.assertIn("LoRA", message)
        self.assertIn("disabled", message)

    def test_enable_peft_lora_calls_enable_adapter_layers(self):
        """Re-checking use_lora should call enable_adapter_layers for PEFT."""
        handler = _DummyHandler(adapter_type="lora")
        handler.use_lora = False
        decoder = handler.model.decoder
        decoder.enable_adapter_layers = Mock()
        decoder.disable_adapter_layers = Mock()
        decoder.set_adapter = Mock()

        message = set_use_lora(handler, True)

        decoder.enable_adapter_layers.assert_called_once()
        self.assertTrue(handler.use_lora)
        self.assertIn("LoRA", message)
        self.assertIn("enabled", message)

    def test_no_adapter_loaded_returns_error(self):
        """Enabling with no adapter loaded should return error."""
        handler = _DummyHandler()
        handler.lora_loaded = False
        handler.use_lora = False

        self.assertIn("❌", set_use_lora(handler, True))


class SetLokrScaleTests(unittest.TestCase):
    """Tests for set_lora_scale with the LoKr adapter type."""

    def test_scale_lokr_sets_multiplier(self):
        """Setting scale on LoKr should call set_multiplier with the value."""
        handler = _DummyHandler(adapter_type="lokr")
        net = SimpleNamespace(set_multiplier=Mock())
        handler.model.decoder._lycoris_net = net

        message = set_lora_scale(handler, 0.6)

        net.set_multiplier.assert_called_once_with(0.6)
        self.assertIn("LoKr", message)
        self.assertIn("0.60", message)
        self.assertAlmostEqual(handler.lora_scale, 0.6)

    def test_scale_lokr_when_disabled_stores_but_does_not_apply(self):
        """Scale change while disabled should store value but not call multiplier."""
        handler = _DummyHandler(adapter_type="lokr")
        handler.use_lora = False
        net = SimpleNamespace(set_multiplier=Mock())
        handler.model.decoder._lycoris_net = net

        message = set_lora_scale(handler, 0.3)

        net.set_multiplier.assert_not_called()
        self.assertAlmostEqual(handler.lora_scale, 0.3)
        self.assertIn("disabled", message)


if __name__ == "__main__":
unittest.main()
Loading