@@ -93,7 +93,7 @@ def test_replicate_convergence(self):
         for fsdp2_loss, simple_fsdp_replicate_loss in zip(
             fsdp2_losses, simple_fsdp_replicate_losses
         ):
-            assert torch.allclose(fsdp2_loss, simple_fsdp_replicate_loss)
+            assert torch.equal(fsdp2_loss, simple_fsdp_replicate_loss)

     def test_fullyshard_convergence(self):
         # unit test for fully_shard mode
@@ -109,7 +109,7 @@ def test_fullyshard_convergence(self):
         for fsdp2_loss, simple_fsdp_fullyshard_loss in zip(
             fsdp2_losses, simple_fsdp_fullyshard_losses
         ):
-            assert torch.allclose(fsdp2_loss, simple_fsdp_fullyshard_loss)
+            assert torch.equal(fsdp2_loss, simple_fsdp_fullyshard_loss)

     def test_hybridshard_convergence(self):
         # unit test for hybrid_shard mode
@@ -125,4 +125,4 @@ def test_hybridshard_convergence(self):
         for fsdp2_loss, simple_fsdp_hybridshard_loss in zip(
             fsdp2_losses, simple_fsdp_hybridshard_losses
         ):
-            assert torch.allclose(fsdp2_loss, simple_fsdp_hybridshard_loss)
+            assert torch.equal(fsdp2_loss, simple_fsdp_hybridshard_loss)
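For reference, this change swaps an approximate comparison for an exact one: `torch.allclose` passes when values agree within floating-point tolerances (defaults `rtol=1e-05`, `atol=1e-08`), while `torch.equal` requires the same shape and exactly matching values. A minimal sketch of the distinction (the tensors here are illustrative, not taken from the test):

```python
import torch

a = torch.tensor([1.0, 2.0, 3.0])
b = a + 1e-6  # small numerical drift, well within allclose's default tolerance

assert torch.allclose(a, b)       # approximate match: passes
assert not torch.equal(a, b)      # exact match: fails on the drifted values
assert torch.equal(a, a.clone())  # identical values and shape: passes
```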