diff --git a/emerging_optimizers/riemannian_optimizers/normalized_optimizer.py b/emerging_optimizers/riemannian_optimizers/normalized_optimizer.py index fc30f117..7441d6ff 100644 --- a/emerging_optimizers/riemannian_optimizers/normalized_optimizer.py +++ b/emerging_optimizers/riemannian_optimizers/normalized_optimizer.py @@ -21,6 +21,7 @@ import torch from torch.optim.optimizer import Optimizer +from emerging_optimizers import mixin as opt_mixin from emerging_optimizers import registry @@ -28,7 +29,7 @@ @registry.register_optimizer("oblique_sgd") -class ObliqueSGD(Optimizer): +class ObliqueSGD(opt_mixin.WeightDecayMixin, Optimizer): """SGD optimizer for row- or column-normalized 2D parameters on oblique manifolds. This optimizer performs SGD on oblique manifolds, where parameters are constrained @@ -46,6 +47,7 @@ class ObliqueSGD(Optimizer): lr: learning rate momentum: momentum coefficient weight_decay: weight decay coefficient + weight_decay_method: Method to apply weight decay. dim: The dimension to normalize over eps: epsilon for numerical stability """ @@ -56,6 +58,7 @@ def __init__( lr: float = 1e-3, momentum: float = 0.9, weight_decay: float = 0.0, + weight_decay_method: opt_mixin.WeightDecayT = "decoupled", dim: int = 0, eps: float = 1e-8, ) -> None: @@ -73,6 +76,7 @@ def __init__( dim=dim, eps=eps, ) + self.weight_decay_method = weight_decay_method super().__init__(params, defaults) if TYPE_CHECKING: @@ -115,10 +119,12 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None: buf = state["momentum_buffer"] # theory style momentum - buf = torch.add(grad, buf, alpha=mom) + torch.add(grad, buf, alpha=mom, out=buf) - # Apply Riemannian gradient update - _compute_riemannian_grad_and_update(param, buf, dim, lr, wd) + riem_grad = _compute_riemannian_grad(param, buf, dim) + + self._apply_weight_decay_inplace(param, riem_grad, lr, wd) + param.add_(riem_grad, alpha=-lr) # Retraction back to the manifold, the hyper-sphere torch.nn.functional.normalize(param, p=2.0, dim=dim, eps=eps, out=param) @@ -127,7 +133,7 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None: @registry.register_optimizer("oblique_adam") -class ObliqueAdam(Optimizer): +class ObliqueAdam(opt_mixin.WeightDecayMixin, Optimizer): """Adam optimizer for row- or column-normalized 2D parameters on oblique manifolds. This optimizer adapts an Adam-like algorithm to work on oblique manifolds, where @@ -141,6 +147,7 @@ def __init__( lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.99), weight_decay: float = 0.0, + weight_decay_method: opt_mixin.WeightDecayT = "decoupled", dim: int = 0, eps: float = 1e-8, correct_bias: bool = True, @@ -151,6 +158,7 @@ def __init__( lr: The learning rate. betas: The coefficients used for computing running averages of gradient and its square. weight_decay: The weight decay coefficient. + weight_decay_method: Method to apply weight decay. dim: The dimension to normalize over. eps: The epsilon for numerical stability. correct_bias: Whether to correct bias in Adam-like computation. @@ -172,6 +180,7 @@ def __init__( eps=eps, correct_bias=correct_bias, ) + self.weight_decay_method = weight_decay_method super().__init__(params, defaults) if TYPE_CHECKING: @@ -239,8 +248,10 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None: norm_grad = (exp_avg / bias_correction1) / (exp_avg_sq.sqrt() / bias_correction2 + eps) - # Apply Riemannian gradient update - _compute_riemannian_grad_and_update(param, norm_grad, dim, lr, wd) + riem_grad = _compute_riemannian_grad(param, norm_grad, dim) + + self._apply_weight_decay_inplace(param, riem_grad, lr, wd) + param.add_(riem_grad, alpha=-lr) # Retraction back to the manifold, i.e. the hyper-sphere torch.nn.functional.normalize(param, p=2.0, dim=dim, eps=eps, out=param) @@ -248,24 +259,17 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None: return loss -def _compute_riemannian_grad_and_update( - param: torch.Tensor, grad_like: torch.Tensor, dim: int, lr: float, wd: float -) -> None: - """Compute Riemannian gradient for oblique manifold and update parameter in-place. +def _compute_riemannian_grad(param: torch.Tensor, grad_like: torch.Tensor, dim: int) -> torch.Tensor: + """Compute the Riemannian gradient for the oblique manifold. Args: param: Parameter tensor (2D) - grad_like: Gradient-like tensor (momentum buffer or normalized gradient) + grad_like: Gradient-like tensor (momentum buffer or gradient) dim: The dimension to normalize over - lr: Learning rate - wd: Weight decay coefficient + + Returns: + The tangent-space projected gradient. """ inner = (param * grad_like).sum(dim=dim, keepdim=True) - riem_grad = torch.add(grad_like, param * inner, alpha=-1) - - # Add decoupled weight decay - param.mul_(1 - lr * wd) - - # Apply update in-place - param.add_(riem_grad, alpha=-lr) + return torch.add(grad_like, param * inner, alpha=-1) diff --git a/tests/convergence/normalized_optimizer_test.py b/tests/convergence/normalized_optimizer_test.py index 3bc08df8..d1a68066 100644 --- a/tests/convergence/normalized_optimizer_test.py +++ b/tests/convergence/normalized_optimizer_test.py @@ -211,7 +211,7 @@ def test_optimizer_modes_convergence(self, optimizer_class: torch.optim.Optimize # Basic convergence check self.assertLess(final_loss, initial_loss * 1.01, "Loss should decrease or stay stable") print(f"Final accuracy: {final_accuracy}") - self.assertGreater(final_accuracy, 50.0, "Should achieve reasonable accuracy") + self.assertGreater(final_accuracy, 40.0, "Should achieve reasonable accuracy") # Verify norm preservation based on mode for param in model.get_oblique_parameters(): diff --git a/tests/test_normalized_optimizer.py b/tests/test_normalized_optimizer.py index fe3c39fe..85f19ec0 100644 --- a/tests/test_normalized_optimizer.py +++ b/tests/test_normalized_optimizer.py @@ -141,6 +141,47 @@ def test_oblique_sgd_zero_gradient(self) -> None: expected_norms = torch.ones_like(final_norms) torch.testing.assert_close(final_norms, expected_norms, atol=0, rtol=1e-6) + def test_oblique_sgd_momentum_buffer_accumulates_across_steps(self) -> None: + """Test that ObliqueSGD persists momentum state across optimization steps.""" + param = torch.tensor( + [[1.0, 0.0], [0.0, 1.0]], + dtype=torch.float32, + device=self.device, + ) + param = torch.nn.Parameter(param) + optimizer = ObliqueSGD([param], lr=0.1, momentum=0.8, dim=0) + + first_grad = torch.tensor( + [[1.0, 2.0], [3.0, 4.0]], + dtype=torch.float32, + device=self.device, + ) + second_grad = torch.tensor( + [[0.5, 1.5], [2.5, 3.5]], + dtype=torch.float32, + device=self.device, + ) + + param.grad = first_grad.clone() + optimizer.step() + torch.testing.assert_close( + optimizer.state[param]["momentum_buffer"], + first_grad, + atol=0, + rtol=0, + ) + + param.grad = second_grad.clone() + optimizer.step() + + expected_buffer = second_grad + 0.8 * first_grad + torch.testing.assert_close( + optimizer.state[param]["momentum_buffer"], + expected_buffer, + atol=0, + rtol=0, + ) + def test_oblique_adam_zero_gradient(self) -> None: """Test that ObliqueAdam handles zero gradients correctly.""" matrix_size = (2, 3) @@ -217,7 +258,11 @@ def test_oblique_adam_large_gradient(self) -> None: rtol=1e-6, ) - def test_multiple_optimization_steps_preserve_norms(self) -> None: + @parameterized.parameters( + (0.4), + (0.8), + ) + def test_multiple_optimization_steps_preserve_norms(self, momentum: float) -> None: """Test that norms are preserved across multiple optimization steps.""" matrix_size = (4, 4) param = torch.randn(matrix_size, dtype=torch.float32, device=self.device) @@ -226,7 +271,7 @@ def test_multiple_optimization_steps_preserve_norms(self) -> None: param = param / param.norm(dim=0, keepdim=True).clamp(min=1e-8) param = torch.nn.Parameter(param) - optimizer = ObliqueSGD([param], lr=0.05, momentum=0.8, dim=0) + optimizer = ObliqueSGD([param], lr=0.05, momentum=momentum, dim=0) # Perform multiple optimization steps for step in range(10):