Skip to content

Commit d7bda2a

Browse files
Merge pull request #148 from PanZezhong1725/debug_causal_softmax
Successfully debugged the BANG causal softmax
2 parents 7062ef4 + dc10e20 commit d7bda2a

File tree

2 files changed

+16
-17
lines changed

2 files changed

+16
-17
lines changed

src/ops/causal_softmax/bang/causal_softmax_bang.mlu

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -707,22 +707,21 @@ __mlu_global__ void causal_softmaxDim_3(T *destination, int strideD_f, int strid
707707
__bang_write_value(src, dimS, -INFINITY);
708708
__bang_write_zero(destSumFinal, wSize);
709709
int lastI = i % middle;
710-
__memcpy(src, destination + indd, (mask + 1 + lastI) * sizeof(T), GDRAM2NRAM);
710+
__memcpy(src, destination + indd, (mask + 1 + lastI) * sizeof(T), GDRAM2NRAM);//长度为dimsize的向量,只考虑前面mask + 1 + lastI部分的softmax
711711
__bang_argmax(srcMax, src, dimS);
712-
__bang_write_value(destSum, dimS, srcMax[0]);
713-
__memcpy(destSum, src, (mask + 1 + lastI) * sizeof(T), NRAM2NRAM);
714-
__bang_sub_scalar(destSum, destSum, srcMax[0], dimS);
715-
__bang_active_exp_less_0(destSum, destSum, dimS);
716-
__bang_write_zero(src, dimS);
717-
__memcpy(src, destSum, (mask + 1 + lastI) * sizeof(T), NRAM2NRAM);
712+
__bang_write_zero(destSum, dimS);
713+
__memcpy(destSum, src, (mask + 1 + lastI) * sizeof(T), NRAM2NRAM);//初始化destSum为0,前面mask + 1 + lastI部分元素和src保持一致
714+
__bang_sub_scalar(destSum, destSum, srcMax[0], mask + 1 + lastI);//前面mask + 1 + lastI元素减去最大值M,后面的元素还是0
715+
__bang_active_exp_less_0(destSum, destSum, mask + 1 + lastI);//前面mask + 1 + lastI元素做指数变换,后面的元素还是0
716+
__memcpy(src, destSum, dimS * sizeof(T), NRAM2NRAM);
718717
int segNum = dimS / wSize;//准备数值求和
719718
for (int strip = segNum / 2; strip > 0; strip = strip / 2) {
720719
for (int j = 0; j < strip; j++) {
721720
__bang_add(destSum + j * wSize, destSum + j * wSize, destSum + (j + strip) * wSize, wSize);
722721
}
723722
}
724-
__bang_reduce_sum(destSumFinal, destSum, wSize); //此时destSum[0]保存的就是当前maxNum长度数据的数值和
725-
T globalSumInv = 1.0 / (destSumFinal[0] - (dimS - (mask + 1 + lastI)));//下面开始指数变换,写回GDRAM
723+
__bang_reduce_sum(destSumFinal, destSum, wSize); //此时destSumFinal[0]存储的是前面mask + 1 + lastI的sum
724+
T globalSumInv = 1.0 / destSumFinal[0];
726725
__bang_mul_scalar(src, src, globalSumInv, dimS);
727726

728727
__memcpy(destination + indd, src, dimsize * sizeof(T), NRAM2GDRAM);

src/ops/causal_softmax/operator.cc

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@ __C infiniopStatus_t infiniopCreateCausalSoftmaxDescriptor(
3636
#endif
3737
#ifdef ENABLE_CAMBRICON_MLU
3838
case DevCambriconMlu: {
39-
// return bangCreateCausalSoftmaxDescriptor((BangHandle_t) handle, (CausalSoftmaxBangDescriptor_t *) desc_ptr, y_desc);
40-
return cnnlCreateCausalSoftmaxDescriptor((BangHandle_t) handle, (CausalSoftmaxCnnlDescriptor_t *) desc_ptr, y_desc);
39+
return bangCreateCausalSoftmaxDescriptor((BangHandle_t) handle, (CausalSoftmaxBangDescriptor_t *) desc_ptr, y_desc);
40+
// return cnnlCreateCausalSoftmaxDescriptor((BangHandle_t) handle, (CausalSoftmaxCnnlDescriptor_t *) desc_ptr, y_desc);
4141
}
4242
#endif
4343
#ifdef ENABLE_ASCEND_NPU
@@ -63,8 +63,8 @@ __C infiniopStatus_t infiniopGetCausalSoftmaxWorkspaceSize(infiniopCausalSoftmax
6363
#endif
6464
#ifdef ENABLE_CAMBRICON_MLU
6565
case DevCambriconMlu: {
66-
// return bangGetCausalSoftmaxWorkspaceSize((CausalSoftmaxBangDescriptor_t) desc, size);
67-
return cnnlGetCausalSoftmaxWorkspaceSize((CausalSoftmaxCnnlDescriptor_t) desc, size);
66+
return bangGetCausalSoftmaxWorkspaceSize((CausalSoftmaxBangDescriptor_t) desc, size);
67+
// return cnnlGetCausalSoftmaxWorkspaceSize((CausalSoftmaxCnnlDescriptor_t) desc, size);
6868
}
6969

7070
#endif
@@ -91,8 +91,8 @@ __C infiniopStatus_t infiniopCausalSoftmax(infiniopCausalSoftmaxDescriptor_t des
9191
#endif
9292
#ifdef ENABLE_CAMBRICON_MLU
9393
case DevCambriconMlu: {
94-
// return bangCausalSoftmax((CausalSoftmaxBangDescriptor_t) desc, workspace, workspace_size, data, stream);
95-
return cnnlCausalSoftmax((CausalSoftmaxCnnlDescriptor_t) desc, workspace, workspace_size, data, stream);
94+
return bangCausalSoftmax((CausalSoftmaxBangDescriptor_t) desc, workspace, workspace_size, data, stream);
95+
// return cnnlCausalSoftmax((CausalSoftmaxCnnlDescriptor_t) desc, workspace, workspace_size, data, stream);
9696
}
9797
#endif
9898
#ifdef ENABLE_ASCEND_NPU
@@ -118,8 +118,8 @@ __C infiniopStatus_t infiniopDestroyCausalSoftmaxDescriptor(infiniopCausalSoftma
118118
#endif
119119
#ifdef ENABLE_CAMBRICON_MLU
120120
case DevCambriconMlu: {
121-
// return bangDestroyCausalSoftmaxDescriptor((CausalSoftmaxBangDescriptor_t) desc);
122-
return cnnlDestroyCausalSoftmaxDescriptor((CausalSoftmaxCnnlDescriptor_t) desc);
121+
return bangDestroyCausalSoftmaxDescriptor((CausalSoftmaxBangDescriptor_t) desc);
122+
// return cnnlDestroyCausalSoftmaxDescriptor((CausalSoftmaxCnnlDescriptor_t) desc);
123123
}
124124
#endif
125125
#ifdef ENABLE_ASCEND_NPU

0 commit comments

Comments
 (0)