Skip to content

Commit 78841bf

Browse files
authored
[RFC] Require bitwise equivalence for SimpleFSDP numerics (#1743)
This PR appears to pass. I am not sure if it's wise though. Signed-off-by: Edward Z. Yang <[email protected]>
1 parent 8d20f02 commit 78841bf

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

torchtitan/experiments/simple_fsdp/tests/test_numerics.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def test_replicate_convergence(self):
9393
for fsdp2_loss, simple_fsdp_replicate_loss in zip(
9494
fsdp2_losses, simple_fsdp_replicate_losses
9595
):
96-
assert torch.allclose(fsdp2_loss, simple_fsdp_replicate_loss)
96+
assert torch.equal(fsdp2_loss, simple_fsdp_replicate_loss)
9797

9898
def test_fullyshard_convergence(self):
9999
# unit test for fully_shard mode
@@ -109,7 +109,7 @@ def test_fullyshard_convergence(self):
109109
for fsdp2_loss, simple_fsdp_fullyshard_loss in zip(
110110
fsdp2_losses, simple_fsdp_fullyshard_losses
111111
):
112-
assert torch.allclose(fsdp2_loss, simple_fsdp_fullyshard_loss)
112+
assert torch.equal(fsdp2_loss, simple_fsdp_fullyshard_loss)
113113

114114
def test_hybridshard_convergence(self):
115115
# unit test for hybrid_shard mode
@@ -125,4 +125,4 @@ def test_hybridshard_convergence(self):
125125
for fsdp2_loss, simple_fsdp_hybridshard_loss in zip(
126126
fsdp2_losses, simple_fsdp_hybridshard_losses
127127
):
128-
assert torch.allclose(fsdp2_loss, simple_fsdp_hybridshard_loss)
128+
assert torch.equal(fsdp2_loss, simple_fsdp_hybridshard_loss)

0 commit comments

Comments
 (0)