diff --git a/docs/apidocs/orthogonalized-optimizers.md b/docs/apidocs/orthogonalized-optimizers.md index 051bcfe7..afbf5294 100644 --- a/docs/apidocs/orthogonalized-optimizers.md +++ b/docs/apidocs/orthogonalized-optimizers.md @@ -39,6 +39,12 @@ emerging_optimizers.orthogonalized_optimizers .. autoclass:: MuonHyperball :members: +:hidden:`Spectron` +~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: Spectron + :members: + :hidden:`Newton-Schulz` ~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/emerging_optimizers/orthogonalized_optimizers/__init__.py b/emerging_optimizers/orthogonalized_optimizers/__init__.py index 7e8ddc4d..3ae4ac5d 100644 --- a/emerging_optimizers/orthogonalized_optimizers/__init__.py +++ b/emerging_optimizers/orthogonalized_optimizers/__init__.py @@ -19,3 +19,4 @@ from emerging_optimizers.orthogonalized_optimizers.orthogonalized_optimizer import * from emerging_optimizers.orthogonalized_optimizers.scion import * from emerging_optimizers.orthogonalized_optimizers.spectral_clipping_utils import * +from emerging_optimizers.orthogonalized_optimizers.spectron import * diff --git a/emerging_optimizers/orthogonalized_optimizers/spectron.py b/emerging_optimizers/orthogonalized_optimizers/spectron.py new file mode 100644 index 00000000..c4915060 --- /dev/null +++ b/emerging_optimizers/orthogonalized_optimizers/spectron.py @@ -0,0 +1,265 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@registry.register_optimizer("spectron")
class Spectron(opt_mixin.WeightDecayMixin, optim.Optimizer):
    """Spectron: Low-rank spectral optimizer with orthogonalized momentum.

    Spectron maintains each 2D weight matrix W as a low-rank factorization W = A @ B^T,
    where A ∈ R^(m×r) and B ∈ R^(n×r). It applies momentum, orthogonalizes the updates
    using Newton-Schulz iteration, and scales the learning rate by the spectral radii
    of both factors.

    The algorithm:
        1. Compute gradients with respect to A and B from parameter gradients
        2. Apply momentum to both factors
        3. Orthogonalize momentum buffers using Newton-Schulz iteration
        4. Estimate spectral radius of A and B using power iteration
        5. Update with scaled learning rate: η / (σ_A + σ_B + 1)
        6. Reconstruct full weight matrix W = A @ B^T

    References:
        - Algorithm 1 (Spectron) and Algorithm 3 (PowerIter) from the Spectron paper (https://arxiv.org/abs/2602.12429).
          Low-rank spectral optimization with orthogonalized momentum.

    Warning:
        - This optimizer requires that all parameters passed in are 2D.
        - Low-rank factorization may not be suitable for all parameter types.
        - The first call to :meth:`step` overwrites each parameter with its rank-``r``
          reconstruction ``A @ B^T``.

    Note:
        The factors, momentum buffers and power-iteration vectors are created in fp32
        regardless of the parameter dtype. The parameter gradient is cast to the factor
        dtype before it is used, so reduced-precision (bf16/fp16) parameters are supported.

    Args:
        params: Iterable of parameters to optimize or dicts defining parameter groups
        lr: The learning rate (η in the algorithm). Default: 3e-4
        rank: The rank of the low-rank factorization. Default: 64
        momentum_beta: The momentum decay coefficient (β). Default: 0.9
        weight_decay: The weight decay coefficient. Default: 0.01
        weight_decay_method: Method to apply weight decay. Default: "decoupled"
        fp32_matmul_prec: Precision of matmul operations. Default: "medium"
        num_ns_steps: Number of Newton-Schulz iteration steps. Default: 5
        num_power_iter: Number of power iteration steps for spectral radius. Default: 1
        coefficient_type: Type of coefficient set for Newton-Schulz. Default: "quintic"

    Raises:
        ValueError: If ``lr``, ``rank``, ``momentum_beta``, ``weight_decay``, ``num_ns_steps``
            or ``num_power_iter`` is outside its valid range.
    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 3e-4,
        rank: int = 64,
        momentum_beta: float = 0.9,
        weight_decay: float = 0.01,
        *,
        weight_decay_method: opt_mixin.WeightDecayT = "decoupled",
        fp32_matmul_prec: FP32MatmulPrecT = "medium",
        num_ns_steps: int = 5,
        num_power_iter: int = 1,
        coefficient_type: NSCoeffT = "quintic",
    ) -> None:
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if rank < 1:
            raise ValueError(f"Invalid rank: {rank}")
        if not 0.0 <= momentum_beta < 1.0:
            raise ValueError(f"Invalid momentum_beta: {momentum_beta}")
        if weight_decay < 0.0:
            raise ValueError(f"Invalid weight_decay: {weight_decay}")
        if num_ns_steps < 1:
            raise ValueError(f"num_ns_steps must be at least 1, got {num_ns_steps}")
        if num_power_iter < 1:
            raise ValueError(f"num_power_iter must be at least 1, got {num_power_iter}")

        self.fp32_matmul_prec = fp32_matmul_prec
        self.weight_decay_method = weight_decay_method
        self.rank = rank
        self.num_power_iter = num_power_iter

        # Create orthogonalization function following OrthogonalizedOptimizer pattern
        def scaled_orthogonalize_fn(grad: torch.Tensor) -> torch.Tensor:
            # Lazy logging arguments: this runs twice per parameter per step, so the
            # message should only be formatted when debug logging is actually enabled.
            logging.debug(
                "Orthogonalizing grad with %s steps, %s coefficient",
                num_ns_steps,
                coefficient_type,
            )
            return muon_utils.newton_schulz(
                grad,
                steps=num_ns_steps,
                coefficient_type=coefficient_type,
            )

        self.scaled_orthogonalize_fn = scaled_orthogonalize_fn

        defaults = dict(
            lr=lr,
            momentum_beta=momentum_beta,
            weight_decay=weight_decay,
        )

        super().__init__(params, defaults)

    @overload
    def step(self, closure: None = ...) -> None: ...

    @overload
    def step(self, closure: Callable[[], float]) -> float: ...

    @torch.no_grad()  # type: ignore[misc]
    @override
    def step(self, closure: Callable[[], float] | None = None) -> float | None:
        """Performs a single optimization step.

        Parameters without a gradient are skipped. On the first step for a parameter its
        low-rank state is created from a truncated SVD of the current weights.

        Args:
            closure: A closure that reevaluates the model and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure is given.

        Raises:
            ValueError: If a parameter that has a gradient is not 2D.
        """
        if closure is None:
            loss = None
        else:
            loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue

                if p.ndim != 2:
                    raise ValueError(f"Spectron only supports 2D parameters, got shape {p.shape}")

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state["step"] = 0

                if state["step"] == 0:
                    assert all(
                        key not in state for key in ["factor_A", "factor_B", "momentum_A", "momentum_B", "u_A", "u_B"]
                    ), (
                        "factor_A, factor_B, momentum_A, momentum_B, u_A, u_B should not be initialized at step 0. "
                        "Some mismatch has been created likely in checkpointing"
                    )
                    self._initialize_state(p, state)

                state["step"] += 1

                # Get state variables
                factor_A = state["factor_A"]
                factor_B = state["factor_B"]
                momentum_A = state["momentum_A"]
                momentum_B = state["momentum_B"]
                u_A = state["u_A"]
                u_B = state["u_B"]

                # The factors are kept in fp32 while p.grad follows the parameter dtype.
                # torch.matmul does not type-promote, so a bf16/fp16/fp64 gradient must be
                # cast to the factor dtype before it is multiplied with the factors.
                grad = p.grad.to(dtype=factor_A.dtype)

                # Compute gradients for A and B from parameter gradient
                # Using chain rule: ∂L/∂A = ∂L/∂W @ B, ∂L/∂B = ∂L/∂W^T @ A
                with utils.fp32_matmul_precision("highest"):
                    grad_A = grad @ factor_B  # shape: (m, r)
                    grad_B = grad.mT @ factor_A  # shape: (n, r)

                # Apply weight decay
                self._apply_weight_decay_inplace(factor_A, grad_A, group["lr"], group["weight_decay"])
                self._apply_weight_decay_inplace(factor_B, grad_B, group["lr"], group["weight_decay"])

                # Update momentum buffers (EMA of gradients)
                momentum_A.lerp_(grad_A, 1 - group["momentum_beta"])
                momentum_B.lerp_(grad_B, 1 - group["momentum_beta"])

                # Orthogonalize momentum using Newton-Schulz
                with utils.fp32_matmul_precision(self.fp32_matmul_prec):
                    orth_momentum_A = self.scaled_orthogonalize_fn(momentum_A)
                    orth_momentum_B = self.scaled_orthogonalize_fn(momentum_B)

                with utils.fp32_matmul_precision("highest"):
                    # Estimate spectral radius using power iteration
                    sigma_A, u_A = self._power_iteration(factor_A, u_A, self.num_power_iter)
                    sigma_B, u_B = self._power_iteration(factor_B, u_B, self.num_power_iter)

                # Update power iteration vectors (warm start for the next step)
                state["u_A"] = u_A
                state["u_B"] = u_B

                # Compute scaled learning rate
                scaled_lr = group["lr"] / (sigma_A + sigma_B + 1.0)

                # Update low-rank factors
                factor_A.add_(orth_momentum_A, alpha=-scaled_lr)
                factor_B.add_(orth_momentum_B, alpha=-scaled_lr)

                # Reconstruct full weight matrix: W = A @ B^T
                # (copy_ casts the result back to the parameter dtype)
                with utils.fp32_matmul_precision(self.fp32_matmul_prec):
                    p.copy_(factor_A @ factor_B.mT)

        return loss

    def _initialize_state(self, p: torch.Tensor, state: dict[str, torch.Tensor]) -> None:
        """Initialize low-rank factors and state for a parameter.

        Writes ``factor_A``, ``factor_B``, ``momentum_A``, ``momentum_B``, ``u_A`` and ``u_B``
        into ``state``.

        Args:
            p: The parameter tensor (shape: m × n)
            state: The state dictionary for this parameter
        """
        m, n = p.shape
        r = min(self.rank, m, n)  # Ensure rank doesn't exceed dimensions

        # Initialize A and B using SVD of the parameter
        # This provides a good initialization close to the original weights
        # Low-rank factors are stored in fp32 for numerical stability
        with torch.no_grad():
            U, S, Vh = torch.linalg.svd(p.float(), full_matrices=False)
            # Keep only top r singular values/vectors; sqrt(S) is split between both
            # factors so that A @ B^T reproduces the rank-r truncation of p.
            sqrt_S = torch.sqrt(S[:r])
            factor_A = U[:, :r] * sqrt_S
            factor_B = Vh[:r, :].mT * sqrt_S

        state["factor_A"] = factor_A
        state["factor_B"] = factor_B
        # Momentum buffers are always stored in fp32 for numerical stability
        state["momentum_A"] = torch.zeros_like(factor_A, dtype=torch.float32)
        state["momentum_B"] = torch.zeros_like(factor_B, dtype=torch.float32)

        # Initialize power iteration vectors (normalized random vectors in fp32)
        u_A = torch.randn(m, dtype=torch.float32, device=p.device)
        u_A = u_A / u_A.norm()
        u_B = torch.randn(n, dtype=torch.float32, device=p.device)
        u_B = u_B / u_B.norm()

        state["u_A"] = u_A
        state["u_B"] = u_B

    def _power_iteration(self, X: torch.Tensor, u: torch.Tensor, num_iters: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Estimate the largest singular value using power iteration.

        Args:
            X: The matrix to estimate largest singular value for
            u: The current approximation of the dominant left singular vector
            num_iters: Number of power iteration steps

        Returns:
            Tuple of (largest singular value, updated_u)
        """
        # power_iteration returns (sigma, u, v) but Spectron only needs sigma and u (left singular vector)
        sigma, u, _v = power_iteration(X, u, k=num_iters)
        return sigma, u
def power_iteration(
    W: torch.Tensor,
    u: torch.Tensor,
    k: int = 1,
    eps: float = 1e-8,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Approximate largest singular value and left/right singular vectors using power iteration.

    Implements Algorithm 3 from the Spectron paper (https://arxiv.org/abs/2602.12429). This method iteratively refines
    estimates of the dominant singular value and corresponding left and right singular vectors
    of a matrix W.

    Args:
        W: Matrix of shape (p, q) to analyze
        u: Initial left singular vector of shape (p,). It is re-normalized internally, so it
            does not have to be unit length. If its dtype differs from ``W``, both are promoted
            to their common dtype before iterating.
        k: Number of power iteration steps. Default: 1
        eps: Lower bound applied to norms before dividing, for numerical stability. Default: 1e-8

    Returns:
        Tuple of (sigma, u, v) where:
            - sigma: Approximation of the largest singular value (scalar tensor)
            - u: Updated left singular vector of shape (p,)
            - v: Updated right singular vector of shape (q,)
    """
    # torch.matmul does not type-promote, so mismatched dtypes (e.g. an fp32 warm-start
    # vector with an fp64 or bf16 matrix) would raise. Promote both to a common dtype.
    if u.dtype != W.dtype:
        common_dtype = torch.promote_types(W.dtype, u.dtype)
        W = W.to(common_dtype)
        u = u.to(common_dtype)

    # Ensure initial normalization
    u = u / torch.linalg.vector_norm(u).clamp_min(eps)

    # Power iteration loop
    for _ in range(k):
        # v ← W^T u (right vector)
        v = W.mT @ u

        # v ← v / ||v||_2 (normalize right vector)
        v = v / torch.linalg.vector_norm(v).clamp_min(eps)

        # u ← W v (left vector)
        u = W @ v

        # u ← u / ||u||_2 (normalize left vector)
        u = u / torch.linalg.vector_norm(u).clamp_min(eps)

    # Refresh v from the final u so the returned pair is consistent (v ∝ W^T u);
    # this also defines v when k == 0.
    v = W.mT @ u
    v = v / torch.linalg.vector_norm(v).clamp_min(eps)

    # σ ← u^T W v (Rayleigh quotient approximation)
    sigma = u @ (W @ v)

    # Return σ, u, and v
    return sigma, u, v
tests/test_orthogonalized_optimizer.py -v -2 || error=1 +coverage run -p --source=emerging_optimizers tests/test_spectron.py --device=cuda -v -2 || error=1 coverage run -p --source=emerging_optimizers tests/test_soap_utils.py -v -2 || error=1 coverage run -p --source=emerging_optimizers tests/test_soap.py -v -2 || error=1 coverage run -p --source=emerging_optimizers tests/soap_mnist_test.py -v -2 || error=1 diff --git a/tests/test_spectron.py b/tests/test_spectron.py new file mode 100644 index 00000000..f5c3d3c3 --- /dev/null +++ b/tests/test_spectron.py @@ -0,0 +1,375 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class PowerIterationTest(parameterized.TestCase):
    """Tests for ``power_iteration``.

    Accuracy checks use a matrix with a prescribed, well-separated spectrum and a start
    vector that overlaps equally with every left singular direction. Power iteration
    converges at a rate set by the ratio of the top two singular values, which is close
    to 1 for Gaussian matrices; with a fixed ratio of 0.5 the expected error after ``k``
    steps is known in advance and the assertions do not depend on a lucky random draw.
    """

    def setUp(self) -> None:
        super().setUp()
        # Make every random draw reproducible across runs.
        torch.manual_seed(0)

    @staticmethod
    def _gapped_problem(shape: tuple[int, int]) -> tuple[torch.Tensor, float, torch.Tensor]:
        """Build ``W = U diag(S) V^T`` with ``S_i = 4 * 0.5**i`` and a balanced start vector.

        Returns:
            Tuple of (W, largest singular value, normalized start vector u0).
        """
        m, n = shape
        r = min(m, n)
        # Reduced QR of Gaussian matrices gives orthonormal columns.
        U, _ = torch.linalg.qr(torch.randn(m, r, dtype=torch.float32, device=FLAGS.device))
        V, _ = torch.linalg.qr(torch.randn(n, r, dtype=torch.float32, device=FLAGS.device))
        S = 4.0 * 0.5 ** torch.arange(r, dtype=torch.float32, device=FLAGS.device)
        W = (U * S) @ V.mT
        # Equal weight on every left singular vector.
        u0 = U.sum(dim=1)
        u0 = u0 / u0.norm()
        return W, S[0].item(), u0

    @parameterized.parameters(
        {"shape": (10, 8), "k": 1},
        {"shape": (32, 16), "k": 5},
        {"shape": (64, 32), "k": 10},
        {"shape": (100, 50), "k": 20},
    )
    def test_power_iteration_converges_to_largest_singular_value(self, shape, k) -> None:
        """Test that power iteration approximates the largest singular value."""
        # Matrix with known singular values and a clear spectral gap
        W, true_sigma_max, u = self._gapped_problem(shape)

        # Run power iteration
        sigma_approx, u_out, _v_out = power_iteration(W, u, k=k)

        # Check that approximation is close to true value
        rel_error = abs(sigma_approx.item() - true_sigma_max) / true_sigma_max

        # With more iterations, error should be smaller
        if k >= 10:
            self.assertLess(rel_error, 0.01, f"Relative error {rel_error} too large with {k} iterations")
        else:
            self.assertLess(rel_error, 0.1, f"Relative error {rel_error} too large with {k} iterations")

    def test_power_iteration_output_normalized(self) -> None:
        """Test that power iteration returns normalized left singular vector."""
        W = torch.randn(20, 15, dtype=torch.float32, device=FLAGS.device)
        u = torch.randn(20, dtype=torch.float32, device=FLAGS.device)

        _, u_out, _v_out = power_iteration(W, u, k=5)

        # Check that output is normalized
        torch.testing.assert_close(
            u_out.norm(),
            torch.tensor(1.0, device=FLAGS.device),
            atol=1e-6,
            rtol=1e-6,
        )

    def test_power_iteration_handles_unnormalized_input(self) -> None:
        """Test that power iteration works even with unnormalized input."""
        W = torch.randn(20, 15, dtype=torch.float32, device=FLAGS.device)
        u = torch.randn(20, dtype=torch.float32, device=FLAGS.device) * 100  # Unnormalized

        # Should not raise error and should normalize internally
        sigma, u_out, _v_out = power_iteration(W, u, k=5)

        self.assertIsInstance(sigma.item(), float)
        torch.testing.assert_close(
            u_out.norm(),
            torch.tensor(1.0, device=FLAGS.device),
            atol=1e-6,
            rtol=1e-6,
        )

    def test_power_iteration_deterministic(self) -> None:
        """Test that power iteration is deterministic given same inputs."""
        W = torch.randn(20, 15, dtype=torch.float32, device=FLAGS.device)
        u = torch.randn(20, dtype=torch.float32, device=FLAGS.device)

        sigma1, u1, _v1 = power_iteration(W, u.clone(), k=5)
        sigma2, u2, _v2 = power_iteration(W, u.clone(), k=5)

        torch.testing.assert_close(sigma1, sigma2, atol=0, rtol=0)
        torch.testing.assert_close(u1, u2, atol=0, rtol=0)

    def test_power_iteration_returns_both_singular_vectors(self) -> None:
        """Test that power iteration returns both left and right singular vectors normalized."""
        # The residual checks below only hold once the iteration has converged, so use
        # the gapped matrix: after 10 steps the sub-dominant component is ~0.5**20.
        W, _true_sigma_max, u = self._gapped_problem((20, 15))

        sigma, u_out, v_out = power_iteration(W, u, k=10)

        # Both singular vectors should be normalized
        torch.testing.assert_close(
            u_out.norm(),
            torch.tensor(1.0, device=FLAGS.device),
            atol=1e-6,
            rtol=1e-6,
        )
        torch.testing.assert_close(
            v_out.norm(),
            torch.tensor(1.0, device=FLAGS.device),
            atol=1e-6,
            rtol=1e-6,
        )

        # Check that W @ v ≈ sigma * u (definition of singular vectors)
        Wv = W @ v_out
        sigma_u = sigma * u_out
        torch.testing.assert_close(Wv, sigma_u, atol=1e-4, rtol=1e-4)

        # Check that W^T @ u ≈ sigma * v
        WTu = W.mT @ u_out
        sigma_v = sigma * v_out
        torch.testing.assert_close(WTu, sigma_v, atol=1e-4, rtol=1e-4)
class SpectronTest(parameterized.TestCase):
    """Behavioural tests for the ``Spectron`` optimizer."""

    def setUp(self) -> None:
        super().setUp()
        # Make every random draw reproducible across runs.
        torch.manual_seed(0)

    @parameterized.product(
        shape=[(10, 8), (32, 16), (64, 32)],
        rank=[4, 8, 16],
        weight_decay_method=["decoupled", "independent", "l2"],
        fp32_matmul_prec=["highest", "medium"],
    )
    def test_smoke(self, shape, rank, weight_decay_method, fp32_matmul_prec) -> None:
        """Smoke test Spectron optimizer with various configurations."""
        test_param = nn.Parameter(torch.randn(shape, dtype=torch.float32, device=FLAGS.device))
        test_param.grad = torch.randn_like(test_param)
        param_before = test_param.detach().clone()

        spectron_opt = spectron.Spectron(
            [test_param],
            lr=0.01,
            rank=rank,
            weight_decay=0.01,
            weight_decay_method=weight_decay_method,
            fp32_matmul_prec=fp32_matmul_prec,
        )
        spectron_opt.step()

        # Check that parameter was updated, kept its shape, and stayed finite
        self.assertEqual(test_param.shape, shape)
        self.assertFalse(torch.equal(test_param.detach(), param_before))
        self.assertTrue(torch.isfinite(test_param).all().item())

    @parameterized.parameters(
        {"shape": (32, 16), "rank": 8},
        {"shape": (64, 32), "rank": 16},
        {"shape": (100, 50), "rank": 20},
    )
    def test_low_rank_reconstruction_quality(self, shape, rank) -> None:
        """Test that low-rank factorization preserves parameter reasonably after initialization."""
        # Create parameter with known structure: an exactly rank-`rank` matrix plus small
        # noise. A full-rank Gaussian matrix loses roughly half of its Frobenius norm when
        # truncated to these ranks, which makes any fixed error bound unreliable.
        m, n = shape
        left = torch.randn(m, rank, dtype=torch.float32, device=FLAGS.device)
        right = torch.randn(n, rank, dtype=torch.float32, device=FLAGS.device)
        noise = 0.01 * torch.randn(shape, dtype=torch.float32, device=FLAGS.device)
        test_param = nn.Parameter(left @ right.mT / rank**0.5 + noise)
        original_param = test_param.data.clone()

        spectron_opt = spectron.Spectron(
            [test_param],
            lr=0.0,  # No update, just check initialization
            rank=rank,
            momentum_beta=0.0,
            weight_decay=0.0,
        )

        # Initialize state
        test_param.grad = torch.randn_like(test_param)
        spectron_opt.step()

        # Get state
        state = spectron_opt.state[test_param]
        factor_A = state["factor_A"]
        factor_B = state["factor_B"]

        # Reconstruct should give back the parameter (since lr=0)
        reconstructed = factor_A @ factor_B.mT

        # Check reconstruction quality: only the noise outside the top-`rank` subspace is lost
        rel_error = (reconstructed - original_param).norm() / original_param.norm()

        self.assertLess(rel_error.item(), 0.05, f"Reconstruction error {rel_error.item()} too large")

    def test_momentum_accumulation(self) -> None:
        """Test that momentum is properly accumulated over multiple steps."""
        shape = (32, 16)
        test_param = nn.Parameter(torch.randn(shape, dtype=torch.float32, device=FLAGS.device))

        momentum_beta = 0.9
        spectron_opt = spectron.Spectron(
            [test_param],
            lr=0.01,
            rank=8,
            momentum_beta=momentum_beta,
            weight_decay=0.0,
        )

        # First step
        test_param.grad = torch.ones_like(test_param)
        spectron_opt.step()

        state = spectron_opt.state[test_param]
        # Clone: the buffers are updated in place on the next step
        momentum_A_step1 = state["momentum_A"].clone()
        momentum_B_step1 = state["momentum_B"].clone()

        # Second step with same gradient
        test_param.grad = torch.ones_like(test_param)
        spectron_opt.step()

        momentum_A_step2 = state["momentum_A"]
        momentum_B_step2 = state["momentum_B"]

        # Momentum should have changed (accumulated)
        self.assertFalse(torch.allclose(momentum_A_step1, momentum_A_step2))
        self.assertFalse(torch.allclose(momentum_B_step1, momentum_B_step2))

    def test_spectral_scaling_reduces_lr_for_large_sigma(self) -> None:
        """Test that learning rate is scaled down when spectral radius is large."""
        shape = (32, 16)

        # Create parameter with large norm (will have large spectral radius)
        test_param_large = nn.Parameter(torch.randn(shape, dtype=torch.float32, device=FLAGS.device) * 10)
        test_param_small = nn.Parameter(torch.randn(shape, dtype=torch.float32, device=FLAGS.device) * 0.1)

        test_param_large.grad = torch.ones_like(test_param_large) * 0.01
        test_param_small.grad = torch.ones_like(test_param_small) * 0.01

        lr = 0.1

        opt_large = spectron.Spectron([test_param_large], lr=lr, rank=8, momentum_beta=0.0)
        opt_small = spectron.Spectron([test_param_small], lr=lr, rank=8, momentum_beta=0.0)

        opt_large.step()
        opt_small.step()

        # Get effective learning rates from spectral scaling
        state_large = opt_large.state[test_param_large]
        state_small = opt_small.state[test_param_small]

        # Compute sigma values after step
        sigma_A_large, _, _ = power_iteration(state_large["factor_A"], state_large["u_A"], k=1)
        sigma_B_large, _, _ = power_iteration(state_large["factor_B"], state_large["u_B"], k=1)

        sigma_A_small, _, _ = power_iteration(state_small["factor_A"], state_small["u_A"], k=1)
        sigma_B_small, _, _ = power_iteration(state_small["factor_B"], state_small["u_B"], k=1)

        scaled_lr_large = lr / (sigma_A_large + sigma_B_large + 1.0)
        scaled_lr_small = lr / (sigma_A_small + sigma_B_small + 1.0)

        # Larger spectral radius should result in smaller effective learning rate
        self.assertLess(scaled_lr_large.item(), scaled_lr_small.item())

    def test_rank_capped_by_dimensions(self) -> None:
        """Test that rank is automatically capped by matrix dimensions."""
        shape = (10, 8)  # Small matrix
        test_param = nn.Parameter(torch.randn(shape, dtype=torch.float32, device=FLAGS.device))
        test_param.grad = torch.randn_like(test_param)

        # Request rank larger than min dimension
        spectron_opt = spectron.Spectron(
            [test_param],
            lr=0.01,
            rank=100,  # Larger than both dimensions
        )
        spectron_opt.step()

        state = spectron_opt.state[test_param]
        factor_A = state["factor_A"]
        factor_B = state["factor_B"]

        # Rank should be capped at min(m, n) = 8
        self.assertEqual(factor_A.shape[1], 8)
        self.assertEqual(factor_B.shape[1], 8)

    def test_raises_error_for_1d_params(self) -> None:
        """Test that Spectron raises error for 1D parameters."""
        test_param = nn.Parameter(torch.randn(10, dtype=torch.float32, device=FLAGS.device))
        test_param.grad = torch.randn_like(test_param)

        # The dimensionality check happens in step(), not in the constructor
        spectron_opt = spectron.Spectron([test_param], lr=0.01, rank=4)

        with self.assertRaises(ValueError):
            spectron_opt.step()

    @parameterized.parameters(
        {"num_ns_steps": 1},
        {"num_ns_steps": 3},
        {"num_ns_steps": 5},
        {"num_ns_steps": 10},
    )
    def test_different_ns_steps(self, num_ns_steps) -> None:
        """Test that different numbers of Newton-Schulz steps work."""
        shape = (32, 16)
        test_param = nn.Parameter(torch.randn(shape, dtype=torch.float32, device=FLAGS.device))
        test_param.grad = torch.randn_like(test_param)

        spectron_opt = spectron.Spectron(
            [test_param],
            lr=0.01,
            rank=8,
            num_ns_steps=num_ns_steps,
        )

        # Should not raise error
        spectron_opt.step()

    @parameterized.parameters(
        {"num_power_iter": 1},
        {"num_power_iter": 5},
        {"num_power_iter": 10},
    )
    def test_different_power_iter_steps(self, num_power_iter) -> None:
        """Test that different numbers of power iteration steps work."""
        shape = (32, 16)
        test_param = nn.Parameter(torch.randn(shape, dtype=torch.float32, device=FLAGS.device))
        test_param.grad = torch.randn_like(test_param)

        spectron_opt = spectron.Spectron(
            [test_param],
            lr=0.01,
            rank=8,
            num_power_iter=num_power_iter,
        )

        # Should not raise error
        spectron_opt.step()

    def test_state_persistence_across_steps(self) -> None:
        """Test that optimizer state (A, B, momentum, u) persists correctly across steps."""
        shape = (32, 16)
        test_param = nn.Parameter(torch.randn(shape, dtype=torch.float32, device=FLAGS.device))

        spectron_opt = spectron.Spectron([test_param], lr=0.01, rank=8)

        # First step
        test_param.grad = torch.randn_like(test_param)
        spectron_opt.step()

        state = spectron_opt.state[test_param]
        factor_A_step1 = state["factor_A"].clone()
        u_A_step1 = state["u_A"].clone()

        # Second step
        test_param.grad = torch.randn_like(test_param)
        spectron_opt.step()

        # State should still exist and be updated
        self.assertIn("factor_A", state)
        self.assertIn("u_A", state)

        # Values should have changed
        self.assertFalse(torch.allclose(state["factor_A"], factor_A_step1))
        # u vector should be updated (but might be similar due to slow changes)
        self.assertEqual(state["u_A"].shape, u_A_step1.shape)


if __name__ == "__main__":
    absltest.main()