Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 25 additions & 21 deletions emerging_optimizers/riemannian_optimizers/normalized_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,15 @@
import torch
from torch.optim.optimizer import Optimizer

from emerging_optimizers import mixin as opt_mixin
from emerging_optimizers import registry


__all__ = ["ObliqueSGD", "ObliqueAdam"]


@registry.register_optimizer("oblique_sgd")
class ObliqueSGD(Optimizer):
class ObliqueSGD(opt_mixin.WeightDecayMixin, Optimizer):
"""SGD optimizer for row- or column-normalized 2D parameters on oblique manifolds.

This optimizer performs SGD on oblique manifolds, where parameters are constrained
Expand All @@ -46,6 +47,7 @@ class ObliqueSGD(Optimizer):
lr: learning rate
momentum: momentum coefficient
weight_decay: weight decay coefficient
weight_decay_method: Method to apply weight decay.
dim: The dimension to normalize over
eps: epsilon for numerical stability
"""
Expand All @@ -56,6 +58,7 @@ def __init__(
lr: float = 1e-3,
momentum: float = 0.9,
weight_decay: float = 0.0,
weight_decay_method: opt_mixin.WeightDecayT = "decoupled",
dim: int = 0,
eps: float = 1e-8,
) -> None:
Expand All @@ -73,6 +76,7 @@ def __init__(
dim=dim,
eps=eps,
)
self.weight_decay_method = weight_decay_method
super().__init__(params, defaults)

if TYPE_CHECKING:
Expand Down Expand Up @@ -115,10 +119,12 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:
buf = state["momentum_buffer"]

# theory style momentum
buf = torch.add(grad, buf, alpha=mom)
torch.add(grad, buf, alpha=mom, out=buf)

# Apply Riemannian gradient update
_compute_riemannian_grad_and_update(param, buf, dim, lr, wd)
riem_grad = _compute_riemannian_grad(param, buf, dim)

self._apply_weight_decay_inplace(param, riem_grad, lr, wd)
param.add_(riem_grad, alpha=-lr)

# Retraction back to the manifold, the hyper-sphere
torch.nn.functional.normalize(param, p=2.0, dim=dim, eps=eps, out=param)
Expand All @@ -127,7 +133,7 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:


@registry.register_optimizer("oblique_adam")
class ObliqueAdam(Optimizer):
class ObliqueAdam(opt_mixin.WeightDecayMixin, Optimizer):
"""Adam optimizer for row- or column-normalized 2D parameters on oblique manifolds.

This optimizer adapts an Adam-like algorithm to work on oblique manifolds, where
Expand All @@ -141,6 +147,7 @@ def __init__(
lr: float = 1e-3,
betas: tuple[float, float] = (0.9, 0.99),
weight_decay: float = 0.0,
weight_decay_method: opt_mixin.WeightDecayT = "decoupled",
dim: int = 0,
eps: float = 1e-8,
correct_bias: bool = True,
Expand All @@ -151,6 +158,7 @@ def __init__(
lr: The learning rate.
betas: The coefficients used for computing running averages of gradient and its square.
weight_decay: The weight decay coefficient.
weight_decay_method: Method to apply weight decay.
dim: The dimension to normalize over.
eps: The epsilon for numerical stability.
correct_bias: Whether to correct bias in Adam-like computation.
Expand All @@ -172,6 +180,7 @@ def __init__(
eps=eps,
correct_bias=correct_bias,
)
self.weight_decay_method = weight_decay_method
super().__init__(params, defaults)

if TYPE_CHECKING:
Expand Down Expand Up @@ -239,33 +248,28 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:

norm_grad = (exp_avg / bias_correction1) / (exp_avg_sq.sqrt() / bias_correction2 + eps)

# Apply Riemannian gradient update
_compute_riemannian_grad_and_update(param, norm_grad, dim, lr, wd)
riem_grad = _compute_riemannian_grad(param, norm_grad, dim)

self._apply_weight_decay_inplace(param, riem_grad, lr, wd)
param.add_(riem_grad, alpha=-lr)

# Retraction back to the manifold, i.e. the hyper-sphere
torch.nn.functional.normalize(param, p=2.0, dim=dim, eps=eps, out=param)

return loss


def _compute_riemannian_grad_and_update(
param: torch.Tensor, grad_like: torch.Tensor, dim: int, lr: float, wd: float
) -> None:
"""Compute Riemannian gradient for oblique manifold and update parameter in-place.
def _compute_riemannian_grad(param: torch.Tensor, grad_like: torch.Tensor, dim: int) -> torch.Tensor:
"""Compute the Riemannian gradient for the oblique manifold.

Args:
param: Parameter tensor (2D)
grad_like: Gradient-like tensor (momentum buffer or normalized gradient)
grad_like: Gradient-like tensor (momentum buffer or gradient)
dim: The dimension to normalize over
lr: Learning rate
wd: Weight decay coefficient

Returns:
The tangent-space projected gradient.
"""

inner = (param * grad_like).sum(dim=dim, keepdim=True)
riem_grad = torch.add(grad_like, param * inner, alpha=-1)

# Add decoupled weight decay
param.mul_(1 - lr * wd)

# Apply update in-place
param.add_(riem_grad, alpha=-lr)
return torch.add(grad_like, param * inner, alpha=-1)
2 changes: 1 addition & 1 deletion tests/convergence/normalized_optimizer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def test_optimizer_modes_convergence(self, optimizer_class: torch.optim.Optimize
# Basic convergence check
self.assertLess(final_loss, initial_loss * 1.01, "Loss should decrease or stay stable")
print(f"Final accuracy: {final_accuracy}")
self.assertGreater(final_accuracy, 50.0, "Should achieve reasonable accuracy")
self.assertGreater(final_accuracy, 40.0, "Should achieve reasonable accuracy")

# Verify norm preservation based on mode
for param in model.get_oblique_parameters():
Expand Down
49 changes: 47 additions & 2 deletions tests/test_normalized_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,47 @@ def test_oblique_sgd_zero_gradient(self) -> None:
expected_norms = torch.ones_like(final_norms)
torch.testing.assert_close(final_norms, expected_norms, atol=0, rtol=1e-6)

def test_oblique_sgd_momentum_buffer_accumulates_across_steps(self) -> None:
    """Test that ObliqueSGD persists momentum state across optimization steps."""
    param = torch.tensor(
        [[1.0, 0.0], [0.0, 1.0]],
        dtype=torch.float32,
        device=self.device,
    )
    param = torch.nn.Parameter(param)
    optimizer = ObliqueSGD([param], lr=0.1, momentum=0.8, dim=0)

    first_grad = torch.tensor(
        [[1.0, 2.0], [3.0, 4.0]],
        dtype=torch.float32,
        device=self.device,
    )
    second_grad = torch.tensor(
        [[0.5, 1.5], [2.5, 3.5]],
        dtype=torch.float32,
        device=self.device,
    )

    param.grad = first_grad.clone()
    optimizer.step()
    # After one step the buffer should equal the first gradient exactly
    # (theory-style momentum: buf = grad + momentum * buf, with a fresh buffer).
    torch.testing.assert_close(
        optimizer.state[param]["momentum_buffer"],
        first_grad,
        atol=0,
        rtol=0,
    )

    param.grad = second_grad.clone()
    optimizer.step()

    # Second step accumulates: buf = second_grad + momentum * first_grad.
    expected_buffer = second_grad + 0.8 * first_grad
    torch.testing.assert_close(
        optimizer.state[param]["momentum_buffer"],
        expected_buffer,
        atol=0,
        rtol=0,
    )
Comment on lines +178 to +183
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Strict zero-tolerance assertion may be fragile on second step

The second momentum-buffer assertion uses atol=0, rtol=0 against expected_buffer = second_grad + 0.8 * first_grad. The optimizer computes buf.mul_(0.8).add_(second_grad) while the expected value is computed as second_grad + (0.8 * first_grad) — two different Python/PyTorch expressions. Float32 addition is commutative (a + b == b + a) so the values are identical in this specific case, but the ordering of operations differs and could diverge on other hardware/precision modes.

Consider using a small tolerance to make the test more robust:

Suggested change
torch.testing.assert_close(
optimizer.state[param]["momentum_buffer"],
expected_buffer,
atol=0,
rtol=0,
)
torch.testing.assert_close(
optimizer.state[param]["momentum_buffer"],
expected_buffer,
atol=1e-6,
rtol=1e-6,
)


def test_oblique_adam_zero_gradient(self) -> None:
"""Test that ObliqueAdam handles zero gradients correctly."""
matrix_size = (2, 3)
Expand Down Expand Up @@ -217,7 +258,11 @@ def test_oblique_adam_large_gradient(self) -> None:
rtol=1e-6,
)

def test_multiple_optimization_steps_preserve_norms(self) -> None:
@parameterized.parameters(
(0.4),
(0.8),
)
def test_multiple_optimization_steps_preserve_norms(self, momentum: float) -> None:
"""Test that norms are preserved across multiple optimization steps."""
matrix_size = (4, 4)
param = torch.randn(matrix_size, dtype=torch.float32, device=self.device)
Expand All @@ -226,7 +271,7 @@ def test_multiple_optimization_steps_preserve_norms(self) -> None:
param = param / param.norm(dim=0, keepdim=True).clamp(min=1e-8)

param = torch.nn.Parameter(param)
optimizer = ObliqueSGD([param], lr=0.05, momentum=0.8, dim=0)
optimizer = ObliqueSGD([param], lr=0.05, momentum=momentum, dim=0)

# Perform multiple optimization steps
for step in range(10):
Expand Down
Loading