@@ -1827,6 +1827,43 @@ def test_stride_for_index_Tensor(self):
 
         self.assertEqual(out.stride(), f_out.stride())
 
+
+    @parametrize("in_dtype", [torch.float32, torch.float16])
+    @parametrize("bias_dtype", [torch.float32, torch.float16, None])
+    def test_mixed_dtype_for_native_layer_norm_backward(self, in_dtype, bias_dtype):
+        if in_dtype == torch.float16 and bias_dtype == torch.float32:
+            self.skipTest(f"input dtype {in_dtype} with bias dtype {bias_dtype} is not supported")
+        device = "meta"
+
+        def fn(input, weight, bias, need_grad_input):
+            outputs = torch.nn.functional.layer_norm(input, input.shape[-1:], weight, bias)
+            grad_outs = torch.ones_like(outputs)
+            grad_ins = torch.autograd.grad(outputs, need_grad_input, grad_outs)
+            return grad_ins
+
+        input = torch.randn([4, 8, 5], dtype=in_dtype, device=device, requires_grad=True)
+        need_grad_input = [input]
+
+        if bias_dtype:
+            weight = torch.randn(
+                [5], dtype=bias_dtype, device=device, requires_grad=True
+            )
+            bias = torch.randn(
+                [5], dtype=bias_dtype, device=device, requires_grad=True
+            )
+            need_grad_input.append(weight)
+            need_grad_input.append(bias)
+        else:
+            weight = None
+            bias = None
+
+        outs = fn(input, weight, bias, need_grad_input)
+        out_dtype = [t.dtype for t in outs]
+        if bias_dtype:
+            self.assertEqual(out_dtype, [in_dtype, bias_dtype, bias_dtype])
+        else:
+            self.assertEqual(out_dtype, [in_dtype])
+
 instantiate_device_type_tests(TestMeta, globals())
 
 def print_op_str_if_not_supported(op_str):
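
For context, the added test checks the gradient dtypes of mixed-dtype layer_norm on the "meta" device. A minimal standalone sketch of the same check (not part of this commit; it mirrors the float32-input / float16-weight parametrization and assumes a PyTorch build whose meta kernel accepts that dtype mix, which is exactly what the test asserts):

import torch

# Sketch only: float32 input with float16 weight/bias on the meta device,
# so shapes and dtypes are computed without any real data.
x = torch.randn(4, 8, 5, dtype=torch.float32, device="meta", requires_grad=True)
w = torch.randn(5, dtype=torch.float16, device="meta", requires_grad=True)
b = torch.randn(5, dtype=torch.float16, device="meta", requires_grad=True)

out = torch.nn.functional.layer_norm(x, x.shape[-1:], w, b)
grads = torch.autograd.grad(out, [x, w, b], torch.ones_like(out))

# Each gradient keeps the dtype of the tensor it corresponds to:
# float32 for the input grad, float16 for the weight and bias grads.
assert [g.dtype for g in grads] == [torch.float32, torch.float16, torch.float16]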