@@ -30,7 +30,10 @@ class RandomSampleDescriptor(Structure):
3030
3131
3232def random_sample (data , random_val , topp , topk , voc , temperature , torch_device ):
33- indices = torch .zeros ([topk ], dtype = torch .int64 )
33+ if (torch_device == "cuda" ):
34+ indices = torch .zeros ([topk ], dtype = torch .uint64 )
35+ else :
36+ indices = torch .zeros ([topk ], dtype = torch .int64 )
3437 dataNp = data .clone ().detach ()
3538 sorted_indices = torch .arange (voc )
3639
@@ -52,7 +55,7 @@ def random_sample(data, random_val, topp, topk, voc, temperature, torch_device):
5255
5356 globalM = dataNp [0 ]
5457 dataNp = (dataNp - globalM ) / temperature
55- dataNp = torch .softmax (dataNp . float () , dim = 0 )
58+ dataNp = torch .softmax (dataNp , dim = 0 )
5659 sum_s = 0
5760 for end in range (topk ):
5861 sum_s += dataNp [end ]
@@ -88,15 +91,15 @@ def test(lib, handle, torch_device, voc, random_val, topp, topk, temperature, x_
8891 ans = random_sample (data .to ("cpu" ), random_val , topp , topk , voc , temperature , "cpu" )
8992 else :
9093 ans = random_sample_0 (data )
91- if (torch_device == 'mlu' or torch_device == 'npu' ):
94+ if (torch_device != "cuda" ):
9295
9396 indices = torch .zeros ([1 ], dtype = torch .int64 ).to (torch_device )
9497 else :
9598
9699 indices = torch .zeros ([1 ], dtype = torch .uint64 ).to (torch_device )
97100 x_tensor = to_tensor (data , lib )
98101 indices_tensor = to_tensor (indices , lib )
99- if (torch_device == 'mlu' or torch_device == 'npu' ):
102+ if (torch_device == 'mlu' ):
100103 indices_tensor .descriptor .contents .dt = U64 # treat int64 as uint64
101104
102105
0 commit comments