@@ -707,22 +707,21 @@ __mlu_global__ void causal_softmaxDim_3(T *destination, int strideD_f, int strid
707
707
__bang_write_value(src, dimS, -INFINITY);
708
708
__bang_write_zero(destSumFinal, wSize);
709
709
int lastI = i % middle;
710
- __memcpy(src, destination + indd, (mask + 1 + lastI) * sizeof(T), GDRAM2NRAM);
710
+ __memcpy(src, destination + indd, (mask + 1 + lastI) * sizeof(T), GDRAM2NRAM);//长度为dimsize的向量,只考虑前面mask + 1 + lastI部分的softmax
711
711
__bang_argmax(srcMax, src, dimS);
712
- __bang_write_value(destSum, dimS, srcMax[0]);
713
- __memcpy(destSum, src, (mask + 1 + lastI) * sizeof(T), NRAM2NRAM);
714
- __bang_sub_scalar(destSum, destSum, srcMax[0], dimS);
715
- __bang_active_exp_less_0(destSum, destSum, dimS);
716
- __bang_write_zero(src, dimS);
717
- __memcpy(src, destSum, (mask + 1 + lastI) * sizeof(T), NRAM2NRAM);
712
+ __bang_write_zero(destSum, dimS);
713
+ __memcpy(destSum, src, (mask + 1 + lastI) * sizeof(T), NRAM2NRAM);//初始化destSum为0,前面mask + 1 + lastI部分元素和src保持一致
714
+ __bang_sub_scalar(destSum, destSum, srcMax[0], mask + 1 + lastI);//前面mask + 1 + lastI元素减去最大值M,后面的元素还是0
715
+ __bang_active_exp_less_0(destSum, destSum, mask + 1 + lastI);//前面mask + 1 + lastI元素做指数变换,后面的元素还是0
716
+ __memcpy(src, destSum, dimS * sizeof(T), NRAM2NRAM);
718
717
int segNum = dimS / wSize;//准备数值求和
719
718
for (int strip = segNum / 2; strip > 0; strip = strip / 2) {
720
719
for (int j = 0; j < strip; j++) {
721
720
__bang_add(destSum + j * wSize, destSum + j * wSize, destSum + (j + strip) * wSize, wSize);
722
721
}
723
722
}
724
- __bang_reduce_sum(destSumFinal, destSum, wSize); //此时destSum[0]保存的就是当前maxNum长度数据的数值和
725
- T globalSumInv = 1.0 / ( destSumFinal[0] - (dimS - (mask + 1 + lastI)));//下面开始指数变换,写回GDRAM
723
+ __bang_reduce_sum(destSumFinal, destSum, wSize); //此时destSumFinal[0]存储的是前面mask + 1 + lastI的sum
724
+ T globalSumInv = 1.0 / destSumFinal[0];
726
725
__bang_mul_scalar(src, src, globalSumInv, dimS);
727
726
728
727
__memcpy(destination + indd, src, dimsize * sizeof(T), NRAM2GDRAM);
0 commit comments