-
Notifications
You must be signed in to change notification settings - Fork 21
fixed bug, separated function #145
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -141,6 +141,47 @@ def test_oblique_sgd_zero_gradient(self) -> None: | |||||||||||||||||||||||||
| expected_norms = torch.ones_like(final_norms) | ||||||||||||||||||||||||||
| torch.testing.assert_close(final_norms, expected_norms, atol=0, rtol=1e-6) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| def test_oblique_sgd_momentum_buffer_accumulates_across_steps(self) -> None: | ||||||||||||||||||||||||||
| """Test that ObliqueSGD persists momentum state across optimization steps.""" | ||||||||||||||||||||||||||
| param = torch.tensor( | ||||||||||||||||||||||||||
| [[1.0, 0.0], [0.0, 1.0]], | ||||||||||||||||||||||||||
| dtype=torch.float32, | ||||||||||||||||||||||||||
| device=self.device, | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
| param = torch.nn.Parameter(param) | ||||||||||||||||||||||||||
| optimizer = ObliqueSGD([param], lr=0.1, momentum=0.8, dim=0) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| first_grad = torch.tensor( | ||||||||||||||||||||||||||
| [[1.0, 2.0], [3.0, 4.0]], | ||||||||||||||||||||||||||
| dtype=torch.float32, | ||||||||||||||||||||||||||
| device=self.device, | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
| second_grad = torch.tensor( | ||||||||||||||||||||||||||
| [[0.5, 1.5], [2.5, 3.5]], | ||||||||||||||||||||||||||
| dtype=torch.float32, | ||||||||||||||||||||||||||
| device=self.device, | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| param.grad = first_grad.clone() | ||||||||||||||||||||||||||
| optimizer.step() | ||||||||||||||||||||||||||
| torch.testing.assert_close( | ||||||||||||||||||||||||||
| optimizer.state[param]["momentum_buffer"], | ||||||||||||||||||||||||||
| first_grad, | ||||||||||||||||||||||||||
| atol=0, | ||||||||||||||||||||||||||
| rtol=0, | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| param.grad = second_grad.clone() | ||||||||||||||||||||||||||
| optimizer.step() | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| expected_buffer = second_grad + 0.8 * first_grad | ||||||||||||||||||||||||||
mkhona-nvidia marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||
| torch.testing.assert_close( | ||||||||||||||||||||||||||
| optimizer.state[param]["momentum_buffer"], | ||||||||||||||||||||||||||
| expected_buffer, | ||||||||||||||||||||||||||
| atol=0, | ||||||||||||||||||||||||||
| rtol=0, | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
|
Comment on lines
+178
to
+183
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
The second momentum-buffer assertion uses Consider using a small tolerance to make the test more robust:
Suggested change
|
||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| def test_oblique_adam_zero_gradient(self) -> None: | ||||||||||||||||||||||||||
| """Test that ObliqueAdam handles zero gradients correctly.""" | ||||||||||||||||||||||||||
| matrix_size = (2, 3) | ||||||||||||||||||||||||||
|
|
@@ -217,7 +258,11 @@ def test_oblique_adam_large_gradient(self) -> None: | |||||||||||||||||||||||||
| rtol=1e-6, | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| def test_multiple_optimization_steps_preserve_norms(self) -> None: | ||||||||||||||||||||||||||
| @parameterized.parameters( | ||||||||||||||||||||||||||
| (0.4), | ||||||||||||||||||||||||||
| (0.8), | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
| def test_multiple_optimization_steps_preserve_norms(self, momentum: float) -> None: | ||||||||||||||||||||||||||
| """Test that norms are preserved across multiple optimization steps.""" | ||||||||||||||||||||||||||
| matrix_size = (4, 4) | ||||||||||||||||||||||||||
| param = torch.randn(matrix_size, dtype=torch.float32, device=self.device) | ||||||||||||||||||||||||||
|
|
@@ -226,7 +271,7 @@ def test_multiple_optimization_steps_preserve_norms(self) -> None: | |||||||||||||||||||||||||
| param = param / param.norm(dim=0, keepdim=True).clamp(min=1e-8) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| param = torch.nn.Parameter(param) | ||||||||||||||||||||||||||
| optimizer = ObliqueSGD([param], lr=0.05, momentum=0.8, dim=0) | ||||||||||||||||||||||||||
| optimizer = ObliqueSGD([param], lr=0.05, momentum=momentum, dim=0) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # Perform multiple optimization steps | ||||||||||||||||||||||||||
| for step in range(10): | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: dtype=torch.float32 is not necessary. default dtype is almost never changed.