From cee185e9de94555c784d21da56b085362adc26d8 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Sun, 22 Dec 2024 19:42:35 -0500 Subject: [PATCH] sync winograd from sd.cpp --- src/ggml-cuda/conv-winograd.cu | 1921 +++++++++++++++++-------------- src/ggml-cuda/conv-winograd.cuh | 88 +- tests/test-conv2d-winograd.cpp | 4 +- 3 files changed, 1112 insertions(+), 901 deletions(-) diff --git a/src/ggml-cuda/conv-winograd.cu b/src/ggml-cuda/conv-winograd.cu index ca89f1657..9967531a8 100644 --- a/src/ggml-cuda/conv-winograd.cu +++ b/src/ggml-cuda/conv-winograd.cu @@ -1,855 +1,1066 @@ -#include "conv-winograd.cuh" -#include "convert.cuh" - -__device__ void __inline__ outer_product(float4* input_frag, float4* filter_frag, float4 accumulator[][16]){ - accumulator[0][0].x += input_frag[0].x*filter_frag[0].x; - accumulator[0][0].y += input_frag[0].y*filter_frag[0].x; - accumulator[0][0].z += input_frag[0].z*filter_frag[0].x; - accumulator[0][0].w += input_frag[0].w*filter_frag[0].x; - - accumulator[0][1].x += input_frag[1].x*filter_frag[0].x; - accumulator[0][1].y += input_frag[1].y*filter_frag[0].x; - accumulator[0][1].z += input_frag[1].z*filter_frag[0].x; - accumulator[0][1].w += input_frag[1].w*filter_frag[0].x; - - accumulator[0][2].x += input_frag[0].x*filter_frag[0].y; - accumulator[0][2].y += input_frag[0].y*filter_frag[0].y; - accumulator[0][2].z += input_frag[0].z*filter_frag[0].y; - accumulator[0][2].w += input_frag[0].w*filter_frag[0].y; - - accumulator[0][3].x += input_frag[1].x*filter_frag[0].y; - accumulator[0][3].y += input_frag[1].y*filter_frag[0].y; - accumulator[0][3].z += input_frag[1].z*filter_frag[0].y; - accumulator[0][3].w += input_frag[1].w*filter_frag[0].y; - - accumulator[0][4].x += input_frag[0].x*filter_frag[0].z; - accumulator[0][4].y += input_frag[0].y*filter_frag[0].z; - accumulator[0][4].z += input_frag[0].z*filter_frag[0].z; - accumulator[0][4].w += input_frag[0].w*filter_frag[0].z; - - accumulator[0][5].x += input_frag[1].x*filter_frag[0].z; - accumulator[0][5].y += input_frag[1].y*filter_frag[0].z; - accumulator[0][5].z += input_frag[1].z*filter_frag[0].z; - accumulator[0][5].w += input_frag[1].w*filter_frag[0].z; - - accumulator[0][6].x += input_frag[0].x*filter_frag[0].w; - accumulator[0][6].y += input_frag[0].y*filter_frag[0].w; - accumulator[0][6].z += input_frag[0].z*filter_frag[0].w; - accumulator[0][6].w += input_frag[0].w*filter_frag[0].w; - - accumulator[0][7].x += input_frag[1].x*filter_frag[0].w; - accumulator[0][7].y += input_frag[1].y*filter_frag[0].w; - accumulator[0][7].z += input_frag[1].z*filter_frag[0].w; - accumulator[0][7].w += input_frag[1].w*filter_frag[0].w; - - // - accumulator[0][8].x += input_frag[0].x*filter_frag[1].x; - accumulator[0][8].y += input_frag[0].y*filter_frag[1].x; - accumulator[0][8].z += input_frag[0].z*filter_frag[1].x; - accumulator[0][8].w += input_frag[0].w*filter_frag[1].x; - - accumulator[0][9].x += input_frag[1].x*filter_frag[1].x; - accumulator[0][9].y += input_frag[1].y*filter_frag[1].x; - accumulator[0][9].z += input_frag[1].z*filter_frag[1].x; - accumulator[0][9].w += input_frag[1].w*filter_frag[1].x; - - accumulator[0][10].x += input_frag[0].x*filter_frag[1].y; - accumulator[0][10].y += input_frag[0].y*filter_frag[1].y; - accumulator[0][10].z += input_frag[0].z*filter_frag[1].y; - accumulator[0][10].w += input_frag[0].w*filter_frag[1].y; - - accumulator[0][11].x += input_frag[1].x*filter_frag[1].y; - accumulator[0][11].y += input_frag[1].y*filter_frag[1].y; - accumulator[0][11].z += input_frag[1].z*filter_frag[1].y; - 
accumulator[0][11].w += input_frag[1].w*filter_frag[1].y; - - accumulator[0][12].x += input_frag[0].x*filter_frag[1].z; - accumulator[0][12].y += input_frag[0].y*filter_frag[1].z; - accumulator[0][12].z += input_frag[0].z*filter_frag[1].z; - accumulator[0][12].w += input_frag[0].w*filter_frag[1].z; - - accumulator[0][13].x += input_frag[1].x*filter_frag[1].z; - accumulator[0][13].y += input_frag[1].y*filter_frag[1].z; - accumulator[0][13].z += input_frag[1].z*filter_frag[1].z; - accumulator[0][13].w += input_frag[1].w*filter_frag[1].z; - - accumulator[0][14].x += input_frag[0].x*filter_frag[1].w; - accumulator[0][14].y += input_frag[0].y*filter_frag[1].w; - accumulator[0][14].z += input_frag[0].z*filter_frag[1].w; - accumulator[0][14].w += input_frag[0].w*filter_frag[1].w; - - accumulator[0][15].x += input_frag[1].x*filter_frag[1].w; - accumulator[0][15].y += input_frag[1].y*filter_frag[1].w; - accumulator[0][15].z += input_frag[1].z*filter_frag[1].w; - accumulator[0][15].w += input_frag[1].w*filter_frag[1].w; - - ////// - accumulator[1][0].x += input_frag[2].x*filter_frag[2].x; - accumulator[1][0].y += input_frag[2].y*filter_frag[2].x; - accumulator[1][0].z += input_frag[2].z*filter_frag[2].x; - accumulator[1][0].w += input_frag[2].w*filter_frag[2].x; - - accumulator[1][1].x += input_frag[3].x*filter_frag[2].x; - accumulator[1][1].y += input_frag[3].y*filter_frag[2].x; - accumulator[1][1].z += input_frag[3].z*filter_frag[2].x; - accumulator[1][1].w += input_frag[3].w*filter_frag[2].x; - - accumulator[1][2].x += input_frag[2].x*filter_frag[2].y; - accumulator[1][2].y += input_frag[2].y*filter_frag[2].y; - accumulator[1][2].z += input_frag[2].z*filter_frag[2].y; - accumulator[1][2].w += input_frag[2].w*filter_frag[2].y; - - accumulator[1][3].x += input_frag[3].x*filter_frag[2].y; - accumulator[1][3].y += input_frag[3].y*filter_frag[2].y; - accumulator[1][3].z += input_frag[3].z*filter_frag[2].y; - accumulator[1][3].w += input_frag[3].w*filter_frag[2].y; - - accumulator[1][4].x += input_frag[2].x*filter_frag[2].z; - accumulator[1][4].y += input_frag[2].y*filter_frag[2].z; - accumulator[1][4].z += input_frag[2].z*filter_frag[2].z; - accumulator[1][4].w += input_frag[2].w*filter_frag[2].z; - - accumulator[1][5].x += input_frag[3].x*filter_frag[2].z; - accumulator[1][5].y += input_frag[3].y*filter_frag[2].z; - accumulator[1][5].z += input_frag[3].z*filter_frag[2].z; - accumulator[1][5].w += input_frag[3].w*filter_frag[2].z; - - accumulator[1][6].x += input_frag[2].x*filter_frag[2].w; - accumulator[1][6].y += input_frag[2].y*filter_frag[2].w; - accumulator[1][6].z += input_frag[2].z*filter_frag[2].w; - accumulator[1][6].w += input_frag[2].w*filter_frag[2].w; - - accumulator[1][7].x += input_frag[3].x*filter_frag[2].w; - accumulator[1][7].y += input_frag[3].y*filter_frag[2].w; - accumulator[1][7].z += input_frag[3].z*filter_frag[2].w; - accumulator[1][7].w += input_frag[3].w*filter_frag[2].w; - - // - accumulator[1][8].x += input_frag[2].x*filter_frag[3].x; - accumulator[1][8].y += input_frag[2].y*filter_frag[3].x; - accumulator[1][8].z += input_frag[2].z*filter_frag[3].x; - accumulator[1][8].w += input_frag[2].w*filter_frag[3].x; - - accumulator[1][9].x += input_frag[3].x*filter_frag[3].x; - accumulator[1][9].y += input_frag[3].y*filter_frag[3].x; - accumulator[1][9].z += input_frag[3].z*filter_frag[3].x; - accumulator[1][9].w += input_frag[3].w*filter_frag[3].x; - - accumulator[1][10].x += input_frag[2].x*filter_frag[3].y; - accumulator[1][10].y += input_frag[2].y*filter_frag[3].y; - 
accumulator[1][10].z += input_frag[2].z*filter_frag[3].y; - accumulator[1][10].w += input_frag[2].w*filter_frag[3].y; - - accumulator[1][11].x += input_frag[3].x*filter_frag[3].y; - accumulator[1][11].y += input_frag[3].y*filter_frag[3].y; - accumulator[1][11].z += input_frag[3].z*filter_frag[3].y; - accumulator[1][11].w += input_frag[3].w*filter_frag[3].y; - - accumulator[1][12].x += input_frag[2].x*filter_frag[3].z; - accumulator[1][12].y += input_frag[2].y*filter_frag[3].z; - accumulator[1][12].z += input_frag[2].z*filter_frag[3].z; - accumulator[1][12].w += input_frag[2].w*filter_frag[3].z; - - accumulator[1][13].x += input_frag[3].x*filter_frag[3].z; - accumulator[1][13].y += input_frag[3].y*filter_frag[3].z; - accumulator[1][13].z += input_frag[3].z*filter_frag[3].z; - accumulator[1][13].w += input_frag[3].w*filter_frag[3].z; - - accumulator[1][14].x += input_frag[2].x*filter_frag[3].w; - accumulator[1][14].y += input_frag[2].y*filter_frag[3].w; - accumulator[1][14].z += input_frag[2].z*filter_frag[3].w; - accumulator[1][14].w += input_frag[2].w*filter_frag[3].w; - - accumulator[1][15].x += input_frag[3].x*filter_frag[3].w; - accumulator[1][15].y += input_frag[3].y*filter_frag[3].w; - accumulator[1][15].z += input_frag[3].z*filter_frag[3].w; - accumulator[1][15].w += input_frag[3].w*filter_frag[3].w; - } - -// extern "C" -// { - -__device__ __forceinline__ void transform_output_tile(float * __restrict__ pOutputs, float2 *C_tile, float2 *At, - int round, int c_tensor, int c_glb_offset, int i1, int i2, - unsigned int mask1, unsigned int mask2, int out_w) -{ - - c_tensor += (((round)/2)*32 + ((round)%2)*2)*c_glb_offset; - int x, x1; - - #pragma unroll - for(int j=0; j<4; j++){ - - At[j].x = C_tile[j].x + C_tile[4+j].x + C_tile[8+j].x; - At[j].y = C_tile[j].y + C_tile[4+j].y + C_tile[8+j].y; - - At[4+j].x = C_tile[4+j].x - C_tile[8+j].x - C_tile[12+j].x; - At[4+j].y = C_tile[4+j].y - C_tile[8+j].y - C_tile[12+j].y; - - } - - #pragma unroll - for(int i=0; i<2; i++){ - x = i*4; - x1 = i*((out_w-(out_w%2)) + (out_w%2)/2); - - if(mask1&(1<<(i*2))){ - pOutputs[x1 + c_tensor + i1] = At[x].x + At[x+1].x + At[x+2].x; - } - if(mask2&(1<<(i*2))){ - pOutputs[x1 + c_tensor + i2] = At[x].y + At[x+1].y + At[x+2].y; - } - if(mask1&(1<<(i*2+1))){ - pOutputs[x1 + c_tensor + i1 + 1] = At[x+1].x - At[x+2].x - At[x+3].x; - } - if(mask2&(1<<(i*2+1))){ - pOutputs[x1 + c_tensor + i2 + 1] = At[x+1].y - At[x+2].y - At[x+3].y; - } - } -} - -__device__ __forceinline__ unsigned int get_mask(int idd, int tiles_dim_w, int tiles_dim_h, - int tw, int th, int out_w, int out_h){ - - unsigned int mask = 0x000F; - // if((blockIdx.y/tiles_dim)==(tiles_dim-1) && out_w%2) mask&=0x0003; // pad bottom row - // if(!((blockIdx.y+1)%tiles_dim) && out_w%2) mask&=0X0005; // pad right col - // if(blockIdx.y==gridDim.y-1 && (idd / tw) == th-1 && out_h%2) mask&=0x0003; // pad bottom row - // if(blockIdx.x==gridDim.x-1 && (idd % tw) == tw-1 && out_w%2) mask&=0X0005; // pad right col - if(tiles_dim_w % tw == 0 && tiles_dim_h % th == 0){ - if(blockIdx.y==gridDim.y-1 && (idd / tw) == th-1 && out_h%2) mask&=0x0003; // pad bottom row - if(blockIdx.x==gridDim.x-1 && (idd % tw) == tw-1 && out_w%2) mask&=0X0005; // pad right col - }else if(tiles_dim_w % tw == 0){ - int k = out_h % TH; - int k1 = k % 2 ? 
(k+1)/2 : k/2; // there could be 4*k1 tiles - if(blockIdx.y==gridDim.y-1 && (idd / tw) == k1-1 && k%2) mask&=0x0003; // pad bottom row - if(blockIdx.y==gridDim.y-1 && (idd / tw) > k1-1) mask &= 0x0; //pad all zeros since this tile does not exist - }else if(tiles_dim_h % th == 0){ - int k = out_w % TW; - int k1 = k % 2 ? (k+1)/2 : k/2; // there could be 4*k1 tiles - if(blockIdx.x==gridDim.x-1 && (idd % tw) == k1-1 && k%2) mask&=0X0005; // pad right col - if(blockIdx.x==gridDim.x-1 && (idd % tw) > k1-1) mask&=0X0; // pad all zeroes - }else{ - int kh = out_h % TH; - int kw = out_w % TW; - int kh1 = kh % 2 ? (kh+1)/2 : kh/2; // there could be kh1*kw1 tiles - int kw1 = kw % 2 ? (kw+1)/2 : kw/2; - if(blockIdx.y==gridDim.y-1 && (idd / tw) == kh1-1 && kh%2) mask&=0x0003; // pad bottom row - if(blockIdx.x==gridDim.x-1 && (idd % tw) == kw1-1 && kw%2) mask&=0X0005; // pad right col - if(blockIdx.y==gridDim.y-1 && (idd / tw) > kh1-1) mask &= 0x0; //pad all zeros since this tile does not exist - if(blockIdx.x==gridDim.x-1 && (idd % tw) > kw1-1) mask &= 0X0; // pad all zeroes - } - return mask; -} - -__device__ __forceinline__ void store_output_tile(float4 acumm_smem[][16], float *shared_mem, float * __restrict__ C, -int out_h, int out_w, int tiles_dim_w, int tiles_dim_h, int tw, int th, -float4 *input_frag_mem, float4* filter_frag_mem){ - - float2 *output_smem = (float2 *) shared_mem; - float2 *accumulator = (float2 *) acumm_smem; - // float2 *C_out = (float2*)C; - - float2 *C_tile = (float2*) input_frag_mem; - float2 *At = (float2*) filter_frag_mem; - // for output transformation, the role of threadIdx.y changes again: - // in the main loop, different threadIdx.y deal with different element of the 4x4 tile - // here, they are for 4 different groups of lane ids from optSTS64 layout - // for init (and init+32), we need to identify its tile number (0-31) within the supertile - // first, from init, find out from which threadIdx.x it comes. - - // now we got l, which is the land id which computed accumulated sum for the tile element - // each lane id (or threadIdx.x) computed 8 tiles which are distributed into 4 locations spreading - // over the smem. We need to find which of the 8 the current tile is. 
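For reference, the 4-bit mask returned by get_mask() above gates the four pixels of one 2x2 output tile: bit (row*2 + col) is set when that pixel lies inside the output image, so 0x0003 keeps only the top row and 0x0005 only the left column. A minimal sketch of a helper equivalent to the bit tests performed in transform_output_tile() (output_pixel_valid is a hypothetical name, not part of this patch):

__device__ inline bool output_pixel_valid(unsigned int mask, int row, int col) {
    return (mask >> (row*2 + col)) & 1;   // row, col index into the 2x2 output tile
}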
- // use tileid table to figure out - - // for 2nd tile - - int idd1 = tileid[0][threadIdx.x]; - int id1 = (idd1 % tw) * 2 + (idd1 / tw) * out_w * 2; - int idd2 = tileid[1][threadIdx.x]; - int id2 = (idd2 % tw) * 2 + (idd2 / tw) * out_w * 2; - - // unsigned short mask1 = 0x000F; - unsigned int mask1 = get_mask(idd1, tiles_dim_w, tiles_dim_h, tw, th, out_w, out_h); - unsigned int mask2 = get_mask(idd2, tiles_dim_w, tiles_dim_h, tw, th, out_w, out_h); - - // output transpose step - int t=0; - int acumm1, acumm2; - // For transposing - //acumm1 = access_s_out[Inx]; //* 4 - acumm1 = ((threadIdx.x%8)/2)*34 + threadIdx.x%2 + (threadIdx.x/16)*2 + ((threadIdx.x/8)%2)*8; - acumm2 = acumm1+4; - - int acumm4 = BN_p*16 ; //*4 - int idx = threadIdx.y * BN_p; - int idx2 = idx + BN_p*8; //(BN_p*2 *8)/2 - - // For transformating - int offset = BN_p *2; //*2/2 - int init = ( (threadIdx.y/4)*BN_p*16 + (threadIdx.y%4)*(32+2) ) *2 + threadIdx.x; - - int c_glb_offset = out_h*out_w; - // int c_tensor = blockIdx.z*c_glb_offset*BK + (blockIdx.y%tiles_dim)*2 + (blockIdx.y/tiles_dim)*out_w*2 + - // blockIdx.x*BN + (threadIdx.x%16)*2+ - // ((threadIdx.x/16)*16 + (threadIdx.y%4)*4 + threadIdx.y/4)*c_glb_offset; - - // int c_tile = blockIdx.x * tx + blockIdx.y * in_w * ty; - // int c_tensor = c_tile + (threadIdx.x % tw) * 2 + (threadIdx.x / tw) * in_w * 2 + - // threadIdx.y*(in_h*in_w) - (in_w+1); - - int c_tensor = blockIdx.z*c_glb_offset*BK + blockIdx.x * TW + blockIdx.y * out_w * TH + - // (threadIdx.x % tw) * 2 + (threadIdx.x / tw) * out_w * 2 + - ((threadIdx.x/16)*16 + (threadIdx.y%4)*4 + threadIdx.y/4)*c_glb_offset; - - #pragma unroll - for(int round=0; round<4; round++){ - - *( (float2*) (output_smem + idx + acumm1) ) = *(accumulator+t); - *( (float2*) (output_smem + idx + acumm1 + 16) ) = *(accumulator+t+1); // float 4, t - *( (float2*) (output_smem + idx + acumm2) ) = *(accumulator+t+2); - *( (float2*) (output_smem + idx + acumm2 + 16) ) = *(accumulator+t+3); // float 4, t+1 - - - *( (float2*) (output_smem + idx2 + acumm1) ) = *(accumulator+t+32); - *( (float2*) (output_smem + idx2 + acumm1 + 16) ) = *(accumulator+t+33); // float 4, t+16 - *( (float2*) (output_smem + idx2 + acumm2) ) = *(accumulator+t+34); - *( (float2*) (output_smem + idx2 + acumm2 + 16) ) = *(accumulator+t+35); // float 4, t+17 - - // the above 8 float2 will be consumed by theadIdx.y = [0,1,2,3] - - // the following 8 float2 will be consumed by theadIdx.y = [4,5,6,7] - - *( (float2*) (output_smem + idx + acumm4 + acumm1) ) = *(accumulator+t+4); - *( (float2*) (output_smem + idx + acumm4 + acumm1 + 16) ) = *(accumulator+t+5); // float 4, t+2 - *( (float2*) (output_smem + idx + acumm4 + acumm2) ) = *(accumulator+t+6); - *( (float2*) (output_smem + idx + acumm4 + acumm2 + 16) ) = *(accumulator+t+7); // float 4, t+3 - - *( (float2*) (output_smem + idx2 + acumm4 + acumm1) ) = *(accumulator+t+36); - *( (float2*) (output_smem + idx2 + acumm4 + acumm1 + 16) ) = *(accumulator+t+37); // float 4, t+18 - *( (float2*) (output_smem + idx2 + acumm4 + acumm2) ) = *(accumulator+t+38); - *( (float2*) (output_smem + idx2 + acumm4 + acumm2 + 16) ) = *(accumulator+t+39); // float 4, t+19 - - - - t+=8; - - __syncthreads(); - - - - - for(int i=0; i<16; i++){ - C_tile[i].x = shared_mem[i*offset + init]; - C_tile[i].y = shared_mem[i*offset + init + 32]; - - } - - - // transform output tiles - transform_output_tile(C, C_tile, At, round, c_tensor, c_glb_offset, id1, id2, mask1, mask2, out_w); - __syncthreads(); - } -} - - -// Set of functions per row in Gw product 
-__device__ float f_row1(float * __restrict__ G, int j){ - return G[j]; - } - __device__ float f_row2(float * __restrict__ G, int j){ - return 0.5f*(G[j] + G[6+j] + G[3+j]); - } - __device__ float f_row3(float * __restrict__ G, int j){ - return 0.5f*(G[j] + G[6+j] - G[3+j]); - } - __device__ float f_row4(float * __restrict__ G, int j){ - return G[6+j]; - } - // Set of functions per column in GwGt product - __device__ float f_col1(float * __restrict__ G, int j){ - return G[j]; - } - __device__ float f_col2(float * __restrict__ G, int j){ - return 0.5f*(G[j] + G[j+2] + G[j+1]); - } - __device__ float f_col3(float * __restrict__ G, int j){ - return 0.5f*(G[j] + G[j+2] - G[j+1]); - } - __device__ float f_col4(float * __restrict__ G, int j){ - return G[j+2]; - } - - template - static __device__ __forceinline__ float t2f32(T val) { - return (float) val; - } - - template <> - __device__ float __forceinline__ t2f32(half val) { - return __half2float(val); - } - - typedef float(*pointFunction_t)(float *, int); - - template - __global__ void FX(const T * __restrict__ pInputs, float * __restrict__ pOutputs, int filt_k, - int filt_c, int filt_h, int filt_w){ - - // assumes KCHW layout - int Inx = threadIdx.x, Iny = threadIdx.y; - int TileX = blockIdx.x, TileY = blockIdx.y; - - // int c_glb_offset = filt_k*filt_h*filt_w; - // int c_kernel = TileY*BC*c_glb_offset + TileX*BK + Iny*c_glb_offset + Inx; - int c_glb_offset = filt_h*filt_w; - // int c_kernel = TileY*BC*c_glb_offset + TileX*BK*filt_c*c_glb_offset + Iny*c_glb_offset+ Inx*filt_c*c_glb_offset; - int c_kernel = (TileY*BC + (TileX*BK+Inx)*filt_c + Iny)*c_glb_offset; - int c_glb_offset_s = filt_k*4*4; - int c_kernel_s = TileY*BC*c_glb_offset_s + TileX*BK + Iny*c_glb_offset_s + Inx; - - float Gw[21]; //9+12. In registers - float *Gw_buffer = Gw+9; - - pointFunction_t func1[4] = {f_row1, f_row2, f_row3, f_row4}; - pointFunction_t func2[4] = {f_col1, f_col2, f_col3, f_col4}; - - for(int bk=0; bk= BC) return; - - // each thread in row 0 puts its first element of 1st filter tile(loaded by the thread) in smem - // taking 32 slots - // then puts its first element of 2nd filter tile immediately after, taking another 32 slots - // then followed by threads in row 1, 2.. until 7 - - // Note the next element is BK*BC (8*64) slots away, then another BK*BC .... - // for every 64 values, the first 32 belongs to filter tile 1, the next 32 for filter tile 2 - - for(int k=0; k<2; k++){ // prefetch 2 filter tiles/thread - for(int i=0; i<4; i++){ - #pragma unroll - for(int j=0; j<4; j++){ - pOutputs[c_tensor_s + i*c_offset_s*4 + j*c_offset_s] = tiles[k*16 + i*4 + j]; - } - } - // 2nd tile right behind the 1st? 
- c_tensor_s += BN; // BN has nothing to do with input tiles - } - -} - -__device__ __forceinline__ void prefetch_filter_tile(const float * __restrict__ pInputs, float * __restrict__ tiles, int filt_k){ - - int c_tensor = blockIdx.z*BK + (threadIdx.y*filt_k<<4) + threadIdx.x; // Iny*filt_k*4*4 - // each threadIdx.y corresponds to one channel; there are 8 different threadIdx.y so 8 channels - - //each thread (32 threads in x direction) loads 2 kernel tiles (32 in K direction apart) - // save the two tiles in a float[32] register, float[16] for each - - int acumm; - #pragma unroll - for(int i=0; i<4; i++){ - acumm = (i*filt_k<<2); - #pragma unroll - for(int j=0; j<4; j++){ - tiles[(i<<2) + j] = pInputs[acumm + j*filt_k + c_tensor]; - tiles[16 + (i<<2) + j] = pInputs[acumm + j*filt_k + c_tensor+BN]; - } - } -} - -__device__ __forceinline__ void prefetch_input_tile(const float * __restrict__ pInputs, float *tile, int in_h, - int in_w, int tw, int th, unsigned short mask){ - - // load one input tile - int tx = TW, ty = TH; - int c_tile = blockIdx.x * tx + blockIdx.y * in_w * ty; - int c_tensor = c_tile + (threadIdx.x % tw) * 2 + (threadIdx.x / tw) * in_w * 2 + - threadIdx.y*(in_h*in_w) - (in_w+1); - - int acumm,x; - - - if(mask==0xFFFF){ - #pragma unroll - for(int i=0; i<4; i++){ - acumm = i*in_w; - #pragma unroll - for(int j=0; j<4; j++){ - tile[(i<<2) + j] = pInputs[acumm + j + c_tensor]; - } - } - - } else { - for(int i=0; i<4; i++){ - acumm = i*in_w; - #pragma unroll - for(int j=0; j<4; j++){ - x = (i<<2) + j; - tile[x] = 0.f; - if(mask&(1< k1-1) m &= 0x0; //pad all zeros since this tile does not exist - }else if(tiles_dim_h % Y == 0){ - int k = in_w % TW; - int k1 = k % 2 ? (k+1)/2 : k/2; // there could be 8*k1 tiles - if(blockIdx.y==gridDim.y-1 && threadIdx.x / X == Y-1) m &= (!(in_h%2))?(0x0FFF):(0x00FF); //pad bottom row or bottom 2 rows - if(blockIdx.x==gridDim.x-1 && threadIdx.x % X == k1-1) m &= (!(k%2))?(0x7777):(0x3333); // pad right col or right 2 cols - if(blockIdx.x==gridDim.x-1 && threadIdx.x % X > k1-1) m &= 0x0; //pad all zeros since this tile does not exist - }else{ - int kh = in_h % TH; - int kw = in_w % TW; - int kh1 = kh % 2 ? (kh+1)/2 : kh/2; // there could be kh1*kw1 tiles - int kw1 = kw % 2 ? (kw+1)/2 : kw/2; - if(blockIdx.y==gridDim.y-1 && threadIdx.x / X == kh1-1) m &= (!(kh%2))?(0x0FFF):(0x00FF); //pad bottom row or bottom 2 rows - if(blockIdx.y==gridDim.y-1 && threadIdx.x / X > kh1-1) m &= 0x0; //pad all zeros since this tile does not exist - if(blockIdx.x==gridDim.x-1 && threadIdx.x % X == kw1-1) m &= (!(kw%2))?(0x7777):(0x3333); // pad right col or right 2 cols - if(blockIdx.x==gridDim.x-1 && threadIdx.x % X > kw1-1) m &= 0x0; //pad all zeros since this tile does not exist - } - if(blockIdx.x==0 && (threadIdx.x % X) == 0) m &=0xeeee; // pad left col - - float img_tile[16]; // Prefetch input from GMEM - float filter_tile[32]; // Prefetch filter from GMEM - - float4 input_frag_mem[8]; //2*2(2*8/4) Data to do Outer Product + prefetch f. SMEM (double_buffer) - float4 filter_frag_mem[8]; //2*2 Data to do Outer Product + prefetch f. 
SMEM (double_buffer) - float4 accumulator[2][16] = {0.0f}; // Accumulators - - float4 *A_frag; // Input data pointer - int frag_offset = 2 * (BN*BC); // (2=8/4) SMEM input read offset - - float4 *B_frag; // Filter data pointer - int f_frag_offset = 2 * (BC*BK); // (2=8/4 with 4 being float4) SMEM filter read offset - - - float4 *input_frag = (float4*) input_frag_mem; - float4 *filter_frag = (float4*) filter_frag_mem; - - float4 *swap_filter; - float4 *swap_input; - - prefetch_input_tile(A, img_tile, in_h, in_w, X, Y, m); - prefetch_filter_tile(B, filter_tile, filt_k); - - float4 *input_frag_buffer = (float4*) (input_frag+4); - float4 *filter_frag_buffer = (float4*) (filter_frag+4); - - // Mainloop - iterates over the entire K dimension - not unrolled - for(int iter=0; iter>>(w, Ww, filt_k, filt_c, filt_h, filt_w); - - // each thread block will load 32 tiles (4x4) from the single image input - // we let X*Y = 32 and arbitraraly pick X = 4 and Y = 8 - Winograd_kernel<<>>(k, Ww, C, - tiles_dim_w, tiles_dim_h, in_c, in_h, in_w, tile_size, X, Y, filt_k, filt_c, out_c, tile_2d_s, out_h, out_w); - - return cudaGetLastError(); -} - -// } - -template -static void conv_winograd_stage0_f32_cuda( - const int src0_ne0, const int src0_ne1, const int src0_ne2, const int src0_ne3, - const int dst_ne0, const int dst_ne1, const int dst_ne2, const int dst_ne3, - const T * src0, float * dst, - cudaStream_t stream) { - - FX<<>>(src0, dst, src0_ne3, src0_ne2, src0_ne1, src0_ne0); - -} - -static void conv_winograd_stage1_f32_f32_cuda(int tiles_dim_w, int tiles_dim_h, int X, int Y, - int tile_size, int tile_2d_s, - const int src0_ne0, const int src0_ne1, const int src0_ne2, const int src0_ne3, - const int src1_ne0, const int src1_ne1, const int src1_ne2, const int src1_ne3, - const int dst_ne0, const int dst_ne1, const int dst_ne2, const int dst_ne3, - const float * src0, const float * src1, float * dst, - cudaStream_t stream) { - - int64_t filt_k = src0_ne0; - int64_t in_c = src1_ne2; - int64_t in_h = src1_ne1; - int64_t in_w = src1_ne0; - int64_t filt_c = src0_ne3; - int64_t out_c = filt_k; - int64_t out_h = in_h; - int64_t out_w = in_w; - int smem_size = (16*BN*BC + 16*BC*BK)*4; - - Winograd_kernel<<>>(src1, src0, dst, - tiles_dim_w, tiles_dim_h, in_c, in_h, in_w, tile_size, X, Y, - filt_k, filt_c, out_c, tile_2d_s, out_h, out_w); -} - - -void ggml_cuda_op_winograd_stage0(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * src0 = dst->src[0]; - // const half * src0_d = (const half *)src0->data; - // in case this tensor has already been computed in a preprocessing step, - // skip this time; - if(src0 == NULL){ - return; - } - - float * dst_d = (float *)dst->data; - cudaStream_t stream = ctx.stream(); - // int id = ggml_cuda_get_device(); - - // GGML_ASSERT(src0->type == GGML_TYPE_F16); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - - // ggml_cuda_pool_alloc src0_ddq_as_f32(ctx.pool(id)); - // if (src0->type != GGML_TYPE_F32) { - // const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type); - // GGML_ASSERT(to_fp32_cuda != nullptr); - // int64_t nle = ggml_nelements(src0); - // src0_ddq_as_f32.alloc(nle); - // const char * src0_dd = (char *)src0->data; - // to_fp32_cuda(src0_dd, src0_ddq_as_f32.get(), nle, stream); - // } - - // // GGML_ASSERT(ggml_is_contiguous(src0)); - // const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? 
(const float *)src0->data : src0_ddq_as_f32.get(); - if(src0->type == GGML_TYPE_F32){ - const float* src0_d = (const float *)src0->data; - // conv_winograd_stage0_f32_f32_cuda(src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], - conv_winograd_stage0_f32_cuda(src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], - dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], - src0_d, dst_d, stream); - }else{ - const half * src0_d = (const half *)src0->data; - // conv_winograd_stage0_f16_f32_cuda(src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], - conv_winograd_stage0_f32_cuda(src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], - dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], - src0_d, dst_d, stream); - } -} - - - -void ggml_cuda_op_winograd_stage1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * src0 = dst->src[0]; - const float * src0_d = (const float *)src0->data; - - const ggml_tensor * src1 = dst->src[1]; - const float * src1_d = (const float *)src1->data; - - float * dst_d = (float *)dst->data; - cudaStream_t stream = ctx.stream(); - - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - - GGML_ASSERT(ggml_is_contiguous(src0)); - GGML_ASSERT(ggml_is_contiguous(src1)); - - const int m = 2; - const int r = 3; - const int tile_size = m+r-1; - int tiles_dim_w, tiles_dim_h; - - tiles_dim_w = ceil(ceil((double)(src1->ne[0]+2)/2)-1); - tiles_dim_h = ceil(ceil((double)(src1->ne[1]+2)/2)-1); - - int tile_2d_s = tile_size*tile_size; - - cudaMemcpyToSymbol(access_f_s, aux, 64*sizeof(int)); - cudaMemcpyToSymbol(access_s, aux2, 64*sizeof(int)); - cudaMemcpyToSymbol(tileid, tid, 64*sizeof(int)); - - conv_winograd_stage1_f32_f32_cuda(tiles_dim_w, tiles_dim_h, 16, 2, - tile_size, tile_2d_s, - src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], - src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], - dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], - src0_d, src1_d, dst_d, stream); - -} - - +#include "conv-winograd.cuh" +#include "convert.cuh" +#include + +#if 0 +__device__ void __inline__ outer_product(float4* input_frag, float4* filter_frag, float4 accumulator[][16]){ + accumulator[0][0].x += input_frag[0].x*filter_frag[0].x; + accumulator[0][0].y += input_frag[0].y*filter_frag[0].x; + accumulator[0][0].z += input_frag[0].z*filter_frag[0].x; + accumulator[0][0].w += input_frag[0].w*filter_frag[0].x; + + accumulator[0][1].x += input_frag[1].x*filter_frag[0].x; + accumulator[0][1].y += input_frag[1].y*filter_frag[0].x; + accumulator[0][1].z += input_frag[1].z*filter_frag[0].x; + accumulator[0][1].w += input_frag[1].w*filter_frag[0].x; + + accumulator[0][2].x += input_frag[0].x*filter_frag[0].y; + accumulator[0][2].y += input_frag[0].y*filter_frag[0].y; + accumulator[0][2].z += input_frag[0].z*filter_frag[0].y; + accumulator[0][2].w += input_frag[0].w*filter_frag[0].y; + + accumulator[0][3].x += input_frag[1].x*filter_frag[0].y; + accumulator[0][3].y += input_frag[1].y*filter_frag[0].y; + accumulator[0][3].z += input_frag[1].z*filter_frag[0].y; + accumulator[0][3].w += input_frag[1].w*filter_frag[0].y; + + accumulator[0][4].x += input_frag[0].x*filter_frag[0].z; + accumulator[0][4].y += input_frag[0].y*filter_frag[0].z; + accumulator[0][4].z += input_frag[0].z*filter_frag[0].z; + accumulator[0][4].w += input_frag[0].w*filter_frag[0].z; + + accumulator[0][5].x += input_frag[1].x*filter_frag[0].z; + accumulator[0][5].y += input_frag[1].y*filter_frag[0].z; + accumulator[0][5].z += input_frag[1].z*filter_frag[0].z; 
+ accumulator[0][5].w += input_frag[1].w*filter_frag[0].z; + + accumulator[0][6].x += input_frag[0].x*filter_frag[0].w; + accumulator[0][6].y += input_frag[0].y*filter_frag[0].w; + accumulator[0][6].z += input_frag[0].z*filter_frag[0].w; + accumulator[0][6].w += input_frag[0].w*filter_frag[0].w; + + accumulator[0][7].x += input_frag[1].x*filter_frag[0].w; + accumulator[0][7].y += input_frag[1].y*filter_frag[0].w; + accumulator[0][7].z += input_frag[1].z*filter_frag[0].w; + accumulator[0][7].w += input_frag[1].w*filter_frag[0].w; + + // + accumulator[0][8].x += input_frag[0].x*filter_frag[1].x; + accumulator[0][8].y += input_frag[0].y*filter_frag[1].x; + accumulator[0][8].z += input_frag[0].z*filter_frag[1].x; + accumulator[0][8].w += input_frag[0].w*filter_frag[1].x; + + accumulator[0][9].x += input_frag[1].x*filter_frag[1].x; + accumulator[0][9].y += input_frag[1].y*filter_frag[1].x; + accumulator[0][9].z += input_frag[1].z*filter_frag[1].x; + accumulator[0][9].w += input_frag[1].w*filter_frag[1].x; + + accumulator[0][10].x += input_frag[0].x*filter_frag[1].y; + accumulator[0][10].y += input_frag[0].y*filter_frag[1].y; + accumulator[0][10].z += input_frag[0].z*filter_frag[1].y; + accumulator[0][10].w += input_frag[0].w*filter_frag[1].y; + + accumulator[0][11].x += input_frag[1].x*filter_frag[1].y; + accumulator[0][11].y += input_frag[1].y*filter_frag[1].y; + accumulator[0][11].z += input_frag[1].z*filter_frag[1].y; + accumulator[0][11].w += input_frag[1].w*filter_frag[1].y; + + accumulator[0][12].x += input_frag[0].x*filter_frag[1].z; + accumulator[0][12].y += input_frag[0].y*filter_frag[1].z; + accumulator[0][12].z += input_frag[0].z*filter_frag[1].z; + accumulator[0][12].w += input_frag[0].w*filter_frag[1].z; + + accumulator[0][13].x += input_frag[1].x*filter_frag[1].z; + accumulator[0][13].y += input_frag[1].y*filter_frag[1].z; + accumulator[0][13].z += input_frag[1].z*filter_frag[1].z; + accumulator[0][13].w += input_frag[1].w*filter_frag[1].z; + + accumulator[0][14].x += input_frag[0].x*filter_frag[1].w; + accumulator[0][14].y += input_frag[0].y*filter_frag[1].w; + accumulator[0][14].z += input_frag[0].z*filter_frag[1].w; + accumulator[0][14].w += input_frag[0].w*filter_frag[1].w; + + accumulator[0][15].x += input_frag[1].x*filter_frag[1].w; + accumulator[0][15].y += input_frag[1].y*filter_frag[1].w; + accumulator[0][15].z += input_frag[1].z*filter_frag[1].w; + accumulator[0][15].w += input_frag[1].w*filter_frag[1].w; + + ////// + accumulator[1][0].x += input_frag[2].x*filter_frag[2].x; + accumulator[1][0].y += input_frag[2].y*filter_frag[2].x; + accumulator[1][0].z += input_frag[2].z*filter_frag[2].x; + accumulator[1][0].w += input_frag[2].w*filter_frag[2].x; + + accumulator[1][1].x += input_frag[3].x*filter_frag[2].x; + accumulator[1][1].y += input_frag[3].y*filter_frag[2].x; + accumulator[1][1].z += input_frag[3].z*filter_frag[2].x; + accumulator[1][1].w += input_frag[3].w*filter_frag[2].x; + + accumulator[1][2].x += input_frag[2].x*filter_frag[2].y; + accumulator[1][2].y += input_frag[2].y*filter_frag[2].y; + accumulator[1][2].z += input_frag[2].z*filter_frag[2].y; + accumulator[1][2].w += input_frag[2].w*filter_frag[2].y; + + accumulator[1][3].x += input_frag[3].x*filter_frag[2].y; + accumulator[1][3].y += input_frag[3].y*filter_frag[2].y; + accumulator[1][3].z += input_frag[3].z*filter_frag[2].y; + accumulator[1][3].w += input_frag[3].w*filter_frag[2].y; + + accumulator[1][4].x += input_frag[2].x*filter_frag[2].z; + accumulator[1][4].y += input_frag[2].y*filter_frag[2].z; + 
accumulator[1][4].z += input_frag[2].z*filter_frag[2].z; + accumulator[1][4].w += input_frag[2].w*filter_frag[2].z; + + accumulator[1][5].x += input_frag[3].x*filter_frag[2].z; + accumulator[1][5].y += input_frag[3].y*filter_frag[2].z; + accumulator[1][5].z += input_frag[3].z*filter_frag[2].z; + accumulator[1][5].w += input_frag[3].w*filter_frag[2].z; + + accumulator[1][6].x += input_frag[2].x*filter_frag[2].w; + accumulator[1][6].y += input_frag[2].y*filter_frag[2].w; + accumulator[1][6].z += input_frag[2].z*filter_frag[2].w; + accumulator[1][6].w += input_frag[2].w*filter_frag[2].w; + + accumulator[1][7].x += input_frag[3].x*filter_frag[2].w; + accumulator[1][7].y += input_frag[3].y*filter_frag[2].w; + accumulator[1][7].z += input_frag[3].z*filter_frag[2].w; + accumulator[1][7].w += input_frag[3].w*filter_frag[2].w; + + // + accumulator[1][8].x += input_frag[2].x*filter_frag[3].x; + accumulator[1][8].y += input_frag[2].y*filter_frag[3].x; + accumulator[1][8].z += input_frag[2].z*filter_frag[3].x; + accumulator[1][8].w += input_frag[2].w*filter_frag[3].x; + + accumulator[1][9].x += input_frag[3].x*filter_frag[3].x; + accumulator[1][9].y += input_frag[3].y*filter_frag[3].x; + accumulator[1][9].z += input_frag[3].z*filter_frag[3].x; + accumulator[1][9].w += input_frag[3].w*filter_frag[3].x; + + accumulator[1][10].x += input_frag[2].x*filter_frag[3].y; + accumulator[1][10].y += input_frag[2].y*filter_frag[3].y; + accumulator[1][10].z += input_frag[2].z*filter_frag[3].y; + accumulator[1][10].w += input_frag[2].w*filter_frag[3].y; + + accumulator[1][11].x += input_frag[3].x*filter_frag[3].y; + accumulator[1][11].y += input_frag[3].y*filter_frag[3].y; + accumulator[1][11].z += input_frag[3].z*filter_frag[3].y; + accumulator[1][11].w += input_frag[3].w*filter_frag[3].y; + + accumulator[1][12].x += input_frag[2].x*filter_frag[3].z; + accumulator[1][12].y += input_frag[2].y*filter_frag[3].z; + accumulator[1][12].z += input_frag[2].z*filter_frag[3].z; + accumulator[1][12].w += input_frag[2].w*filter_frag[3].z; + + accumulator[1][13].x += input_frag[3].x*filter_frag[3].z; + accumulator[1][13].y += input_frag[3].y*filter_frag[3].z; + accumulator[1][13].z += input_frag[3].z*filter_frag[3].z; + accumulator[1][13].w += input_frag[3].w*filter_frag[3].z; + + accumulator[1][14].x += input_frag[2].x*filter_frag[3].w; + accumulator[1][14].y += input_frag[2].y*filter_frag[3].w; + accumulator[1][14].z += input_frag[2].z*filter_frag[3].w; + accumulator[1][14].w += input_frag[2].w*filter_frag[3].w; + + accumulator[1][15].x += input_frag[3].x*filter_frag[3].w; + accumulator[1][15].y += input_frag[3].y*filter_frag[3].w; + accumulator[1][15].z += input_frag[3].z*filter_frag[3].w; + accumulator[1][15].w += input_frag[3].w*filter_frag[3].w; + } +#endif + +template +__device__ void __inline__ outer_product(float4* input_frag, float4* filter_frag, float4 accumulator_in[][16]){ + T *accumulator = (T *)accumulator_in; + T *x = (T *)input_frag; + T *y = (T *)filter_frag; + + *(accumulator) += (*(x+0)) * (*y); + *(accumulator+1) += (*(x+1)) * (*(y)); + *(accumulator+2) += (*(x+2)) * (*(y)); + *(accumulator+3) += (*(x+3)) * (*(y)); + + *(accumulator+4) += (*(x+4)) * (*(y)); + *(accumulator+5) += (*(x+5)) * (*(y)); + *(accumulator+6) += (*(x+6)) * (*(y)); + *(accumulator+7) += (*(x+7)) * (*(y)); + + *(accumulator+8) += (*(x+0)) * (*(y+1)); + *(accumulator+9) += (*(x+1)) * (*(y+1)); + *(accumulator+10) += (*(x+2)) * (*(y+1)); + *(accumulator+11) += (*(x+3)) * (*(y+1)); + + *(accumulator+12) += (*(x+4)) * (*(y+1)); + 
*(accumulator+13) += (*(x+5)) * (*(y+1)); + *(accumulator+14) += (*(x+6)) * (*(y+1)); + *(accumulator+15) += (*(x+7)) * (*(y+1)); + + *(accumulator+16) += (*(x+0)) * (*(y+2)); + *(accumulator+17) += (*(x+1)) * (*(y+2)); + *(accumulator+18) += (*(x+2)) * (*(y+2)); + *(accumulator+19) += (*(x+3)) * (*(y+2)); + + *(accumulator+20) += (*(x+4)) * (*(y+2)); + *(accumulator+21) += (*(x+5)) * (*(y+2)); + *(accumulator+22) += (*(x+6)) * (*(y+2)); + *(accumulator+23) += (*(x+7)) * (*(y+2)); + + *(accumulator+24) += (*(x+0)) * (*(y+3)); + *(accumulator+25) += (*(x+1)) * (*(y+3)); + *(accumulator+26) += (*(x+2)) * (*(y+3)); + *(accumulator+27) += (*(x+3)) * (*(y+3)); + + *(accumulator+28) += (*(x+4)) * (*(y+3)); + *(accumulator+29) += (*(x+5)) * (*(y+3)); + *(accumulator+30) += (*(x+6)) * (*(y+3)); + *(accumulator+31) += (*(x+7)) * (*(y+3)); + + // + *(accumulator+32) += (*(x+0)) * (*(y+4)); + *(accumulator+33) += (*(x+1)) * (*(y+4)); + *(accumulator+34) += (*(x+2)) * (*(y+4)); + *(accumulator+35) += (*(x+3)) * (*(y+4)); + + *(accumulator+36) += (*(x+4)) * (*(y+4)); + *(accumulator+37) += (*(x+5)) * (*(y+4)); + *(accumulator+38) += (*(x+6)) * (*(y+4)); + *(accumulator+39) += (*(x+7)) * (*(y+4)); + + *(accumulator+40) += (*(x+0)) * (*(y+5)); + *(accumulator+41) += (*(x+1)) * (*(y+5)); + *(accumulator+42) += (*(x+2)) * (*(y+5)); + *(accumulator+43) += (*(x+3)) * (*(y+5)); + + *(accumulator+44) += (*(x+4)) * (*(y+5)); + *(accumulator+45) += (*(x+5)) * (*(y+5)); + *(accumulator+46) += (*(x+6)) * (*(y+5)); + *(accumulator+47) += (*(x+7)) * (*(y+5)); + + *(accumulator+48) += (*(x+0)) * (*(y+6)); + *(accumulator+49) += (*(x+1)) * (*(y+6)); + *(accumulator+50) += (*(x+2)) * (*(y+6)); + *(accumulator+51) += (*(x+3)) * (*(y+6)); + + *(accumulator+52) += (*(x+4)) * (*(y+6)); + *(accumulator+53) += (*(x+5)) * (*(y+6)); + *(accumulator+54) += (*(x+6)) * (*(y+6)); + *(accumulator+55) += (*(x+7)) * (*(y+6)); + + *(accumulator+56) += (*(x+0)) * (*(y+7)); + *(accumulator+57) += (*(x+1)) * (*(y+7)); + *(accumulator+58) += (*(x+2)) * (*(y+7)); + *(accumulator+59) += (*(x+3)) * (*(y+7)); + + *(accumulator+60) += (*(x+4)) * (*(y+7)); + *(accumulator+61) += (*(x+5)) * (*(y+7)); + *(accumulator+62) += (*(x+6)) * (*(y+7)); + *(accumulator+63) += (*(x+7)) * (*(y+7)); + + ////// + + *(accumulator+64) += (*(x+8)) * (*(y+8)); + *(accumulator+65) += (*(x+9)) * (*(y+8)); + *(accumulator+66) += (*(x+10)) * (*(y+8)); + *(accumulator+67) += (*(x+11)) * (*(y+8)); + *(accumulator+68) += (*(x+12)) * (*(y+8)); + *(accumulator+69) += (*(x+13)) * (*(y+8)); + *(accumulator+70) += (*(x+14)) * (*(y+8)); + *(accumulator+71) += (*(x+15)) * (*(y+8)); + + *(accumulator+72) += (*(x+8)) * (*(y+9)); + *(accumulator+73) += (*(x+9)) * (*(y+9)); + *(accumulator+74) += (*(x+10)) * (*(y+9)); + *(accumulator+75) += (*(x+11)) * (*(y+9)); + *(accumulator+76) += (*(x+12)) * (*(y+9)); + *(accumulator+77) += (*(x+13)) * (*(y+9)); + *(accumulator+78) += (*(x+14)) * (*(y+9)); + *(accumulator+79) += (*(x+15)) * (*(y+9)); + + *(accumulator+80) += (*(x+8)) * (*(y+10)); + *(accumulator+81) += (*(x+9)) * (*(y+10)); + *(accumulator+82) += (*(x+10)) * (*(y+10)); + *(accumulator+83) += (*(x+11)) * (*(y+10)); + *(accumulator+84) += (*(x+12)) * (*(y+10)); + *(accumulator+85) += (*(x+13)) * (*(y+10)); + *(accumulator+86) += (*(x+14)) * (*(y+10)); + *(accumulator+87) += (*(x+15)) * (*(y+10)); + + *(accumulator+88) += (*(x+8)) * (*(y+11)); + *(accumulator+89) += (*(x+9)) * (*(y+11)); + *(accumulator+90) += (*(x+10)) * (*(y+11)); + *(accumulator+91) += (*(x+11)) * (*(y+11)); 
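For readers tracing the unrolled accumulation above: each thread multiplies an 8-element input fragment by an 8-element filter fragment and accumulates the full 8x8 outer product, once per half of the double fragment (128 scalar FMAs in total). A minimal loop-form sketch of the same update, assuming float accumulators and the hypothetical name outer_product_ref:

__device__ inline void outer_product_ref(const float *x, const float *y, float *acc) {
    // acc holds two halves of 64 accumulators each, laid out as acc[h*64 + j*8 + i]
    for (int h = 0; h < 2; ++h)            // two independent halves of the fragments
        for (int j = 0; j < 8; ++j)        // filter element within the half
            for (int i = 0; i < 8; ++i)    // input element within the half
                acc[h*64 + j*8 + i] += x[h*8 + i] * y[h*8 + j];
}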
+ *(accumulator+92) += (*(x+12)) * (*(y+11)); + *(accumulator+93) += (*(x+13)) * (*(y+11)); + *(accumulator+94) += (*(x+14)) * (*(y+11)); + *(accumulator+95) += (*(x+15)) * (*(y+11)); + + // + + *(accumulator+96) += (*(x+8)) * (*(y+12)); + *(accumulator+97) += (*(x+9)) * (*(y+12)); + *(accumulator+98) += (*(x+10)) * (*(y+12)); + *(accumulator+99) += (*(x+11)) * (*(y+12)); + *(accumulator+100) += (*(x+12)) * (*(y+12)); + *(accumulator+101) += (*(x+13)) * (*(y+12)); + *(accumulator+102) += (*(x+14)) * (*(y+12)); + *(accumulator+103) += (*(x+15)) * (*(y+12)); + + *(accumulator+104) += (*(x+8)) * (*(y+13)); + *(accumulator+105) += (*(x+9)) * (*(y+13)); + *(accumulator+106) += (*(x+10)) * (*(y+13)); + *(accumulator+107) += (*(x+11)) * (*(y+13)); + *(accumulator+108) += (*(x+12)) * (*(y+13)); + *(accumulator+109) += (*(x+13)) * (*(y+13)); + *(accumulator+110) += (*(x+14)) * (*(y+13)); + *(accumulator+111) += (*(x+15)) * (*(y+13)); + + *(accumulator+112) += (*(x+8)) * (*(y+14)); + *(accumulator+113) += (*(x+9)) * (*(y+14)); + *(accumulator+114) += (*(x+10)) * (*(y+14)); + *(accumulator+115) += (*(x+11)) * (*(y+14)); + *(accumulator+116) += (*(x+12)) * (*(y+14)); + *(accumulator+117) += (*(x+13)) * (*(y+14)); + *(accumulator+118) += (*(x+14)) * (*(y+14)); + *(accumulator+119) += (*(x+15)) * (*(y+14)); + + *(accumulator+120) += (*(x+8)) * (*(y+15)); + *(accumulator+121) += (*(x+9)) * (*(y+15)); + *(accumulator+122) += (*(x+10)) * (*(y+15)); + *(accumulator+123) += (*(x+11)) * (*(y+15)); + *(accumulator+124) += (*(x+12)) * (*(y+15)); + *(accumulator+125) += (*(x+13)) * (*(y+15)); + *(accumulator+126) += (*(x+14)) * (*(y+15)); + *(accumulator+127) += (*(x+15)) * (*(y+15)); + + } + +// extern "C" +// { + +__device__ __forceinline__ void transform_output_tile(float * __restrict__ pOutputs, float2 *C_tile, float2 *At, + int round, int c_tensor, int c_glb_offset, int i1, int i2, + unsigned int mask1, unsigned int mask2, int out_w) +{ + + c_tensor += (((round)>>1)*32 + ((round)&1)*2)*c_glb_offset; + int x, x1; + + #pragma unroll + for(int j=0; j<4; j++){ + + At[j].x = C_tile[j].x + C_tile[4+j].x + C_tile[8+j].x; + At[j].y = C_tile[j].y + C_tile[4+j].y + C_tile[8+j].y; + + At[4+j].x = C_tile[4+j].x - C_tile[8+j].x - C_tile[12+j].x; + At[4+j].y = C_tile[4+j].y - C_tile[8+j].y - C_tile[12+j].y; + + } + + #pragma unroll + for(int i=0; i<2; i++){ + x = i*4; + x1 = i*((out_w-(out_w&1)) + (out_w&1)/2); + + if(mask1&(1<<(i*2))){ + pOutputs[x1 + c_tensor + i1] = At[x].x + At[x+1].x + At[x+2].x; + } + if(mask2&(1<<(i*2))){ + pOutputs[x1 + c_tensor + i2] = At[x].y + At[x+1].y + At[x+2].y; + } + if(mask1&(1<<(i*2+1))){ + pOutputs[x1 + c_tensor + i1 + 1] = At[x+1].x - At[x+2].x - At[x+3].x; + } + if(mask2&(1<<(i*2+1))){ + pOutputs[x1 + c_tensor + i2 + 1] = At[x+1].y - At[x+2].y - At[x+3].y; + } + } +} + +__device__ __forceinline__ unsigned int get_mask(int idd, int tiles_dim_w, int tiles_dim_h, + int tw, int th, int out_w, int out_h){ + + unsigned int mask = 0x000F; + // if((blockIdx.y/tiles_dim)==(tiles_dim-1) && out_w%2) mask&=0x0003; // pad bottom row + // if(!((blockIdx.y+1)%tiles_dim) && out_w%2) mask&=0X0005; // pad right col + // if(blockIdx.y==gridDim.y-1 && (idd / tw) == th-1 && out_h%2) mask&=0x0003; // pad bottom row + // if(blockIdx.x==gridDim.x-1 && (idd % tw) == tw-1 && out_w%2) mask&=0X0005; // pad right col + if((tiles_dim_w & (tw-1)) == 0 && (tiles_dim_h & (th-1)) == 0){ + if(blockIdx.y==gridDim.y-1 && (idd / tw) == th-1 && (out_h&1)) mask&=0x0003; // pad bottom row + if(blockIdx.x==gridDim.x-1 && 
(idd & (tw-1)) == tw-1 && (out_w&1)) mask&=0X0005; // pad right col + }else if((tiles_dim_w & (tw-1)) == 0){ + int k = out_h & (TH-1); + int k1 = k & 1 ? (k+1)>>1 : k>>1; // there could be 4*k1 tiles + if(blockIdx.y==gridDim.y-1 && (idd / tw) == k1-1 && (k&1)) mask&=0x0003; // pad bottom row + if(blockIdx.y==gridDim.y-1 && (idd / tw) > k1-1) mask &= 0x0; //pad all zeros since this tile does not exist + }else if((tiles_dim_h & (th-1)) == 0){ + int k = out_w & (TW-1); + int k1 = k & 1 ? (k+1) >> 1 : k >> 1; // there could be 4*k1 tiles + if(blockIdx.x==gridDim.x-1 && (idd & (tw-1)) == k1-1 && (k&1)) mask&=0X0005; // pad right col + if(blockIdx.x==gridDim.x-1 && (idd & (tw-1)) > k1-1) mask&=0X0; // pad all zeroes + }else{ + int kh = out_h & (TH-1); + int kw = out_w & (TW-1); + int kh1 = kh & 1 ? (kh+1) >> 1 : kh >> 1; // there could be kh1*kw1 tiles + int kw1 = kw & 1 ? (kw+1) >> 1 : kw >> 1; + if(blockIdx.y==gridDim.y-1 && (idd / tw) == kh1-1 && (kh&1)) mask&=0x0003; // pad bottom row + if(blockIdx.x==gridDim.x-1 && (idd & (tw-1)) == kw1-1 && (kw&1)) mask&=0X0005; // pad right col + if(blockIdx.y==gridDim.y-1 && (idd / tw) > kh1-1) mask &= 0x0; //pad all zeros since this tile does not exist + if(blockIdx.x==gridDim.x-1 && (idd & (tw-1)) > kw1-1) mask &= 0X0; // pad all zeroes + } + return mask; +} + +__device__ __forceinline__ void store_output_tile(float4 acumm_smem[][16], float *shared_mem, float * __restrict__ C, +int out_h, int out_w, int tiles_dim_w, int tiles_dim_h, int tw, int th, +float4 *input_frag_mem, float4* filter_frag_mem){ + + float2 *output_smem = (float2 *) shared_mem; + float2 *accumulator = (float2 *) acumm_smem; + // float2 *C_out = (float2*)C; + + float2 *C_tile = (float2*) input_frag_mem; + float2 *At = (float2*) filter_frag_mem; + // for output transformation, the role of threadIdx.y changes again: + // in the main loop, different threadIdx.y deal with different element of the 4x4 tile + // here, they are for 4 different groups of lane ids from optSTS64 layout + // for init (and init+32), we need to identify its tile number (0-31) within the supertile + // first, from init, find out from which threadIdx.x it comes. + + // now we got l, which is the land id which computed accumulated sum for the tile element + // each lane id (or threadIdx.x) computed 8 tiles which are distributed into 4 locations spreading + // over the smem. We need to find which of the 8 the current tile is. 
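As context for the tile math: with m = 2 and r = 3 (tile_size = 4, set in the host wrapper below), FX() applies the F(2x2,3x3) filter transform U = G g G^T with G = [[1,0,0],[0.5,0.5,0.5],[0.5,-0.5,0.5],[0,0,1]] (the f_row*/f_col* helpers), the main loop accumulates the element-wise products, and transform_output_tile() applies the inverse transform Y = A^T M A with A^T = [[1,1,1,0],[0,1,-1,-1]]. A scalar single-tile sketch of that inverse transform (winograd_output_2x2_ref is a hypothetical name, not part of this kernel):

__device__ inline void winograd_output_2x2_ref(const float m[4][4], float y[2][2]) {
    float t[2][4];
    for (int j = 0; j < 4; ++j) {          // A^T * M : combine rows of the 4x4 tile
        t[0][j] = m[0][j] + m[1][j] + m[2][j];
        t[1][j] = m[1][j] - m[2][j] - m[3][j];
    }
    for (int i = 0; i < 2; ++i) {          // (A^T*M) * A : combine columns
        y[i][0] = t[i][0] + t[i][1] + t[i][2];
        y[i][1] = t[i][1] - t[i][2] - t[i][3];
    }
}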
+ // use tileid table to figure out + + // for 2nd tile + + int idd1 = tileid[0][threadIdx.x]; + int id1 = (idd1 & (tw-1)) * 2 + (idd1 / tw) * out_w * 2; + int idd2 = tileid[1][threadIdx.x]; + int id2 = (idd2 & (tw-1)) * 2 + (idd2 / tw) * out_w * 2; + + // unsigned short mask1 = 0x000F; + unsigned int mask1 = get_mask(idd1, tiles_dim_w, tiles_dim_h, tw, th, out_w, out_h); + unsigned int mask2 = get_mask(idd2, tiles_dim_w, tiles_dim_h, tw, th, out_w, out_h); + + // output transpose step + int t=0; + int acumm1, acumm2; + // For transposing + //acumm1 = access_s_out[Inx]; //* 4 + acumm1 = ((threadIdx.x&7)>>1)*34 + (threadIdx.x&1) + (threadIdx.x>>4)*2 + ((threadIdx.x>>3)&1)*8; + acumm2 = acumm1+4; + + int acumm4 = BN_p*16 ; //*4 + int idx = threadIdx.y * BN_p; + int idx2 = idx + BN_p*8; //(BN_p*2 *8)/2 + + // For transformating + int offset = BN_p *2; //*2/2 + int init = ( (threadIdx.y>>2)*BN_p*16 + (threadIdx.y&3)*(32+2) ) *2 + threadIdx.x; + + int c_glb_offset = out_h*out_w; + // int c_tensor = blockIdx.z*c_glb_offset*BK + (blockIdx.y%tiles_dim)*2 + (blockIdx.y/tiles_dim)*out_w*2 + + // blockIdx.x*BN + (threadIdx.x%16)*2+ + // ((threadIdx.x/16)*16 + (threadIdx.y%4)*4 + threadIdx.y/4)*c_glb_offset; + + // int c_tile = blockIdx.x * tx + blockIdx.y * in_w * ty; + // int c_tensor = c_tile + (threadIdx.x % tw) * 2 + (threadIdx.x / tw) * in_w * 2 + + // threadIdx.y*(in_h*in_w) - (in_w+1); + + int c_tensor = blockIdx.z*c_glb_offset*BK + blockIdx.x * TW + blockIdx.y * out_w * TH + + // (threadIdx.x % tw) * 2 + (threadIdx.x / tw) * out_w * 2 + + ((threadIdx.x>>4)*16 + (threadIdx.y&3)*4 + (threadIdx.y>>2))*c_glb_offset; + + #pragma unroll + for(int round=0; round<4; round++){ + + *( (float2*) (output_smem + idx + acumm1) ) = *(accumulator+t); + *( (float2*) (output_smem + idx + acumm1 + 16) ) = *(accumulator+t+1); // float 4, t + *( (float2*) (output_smem + idx + acumm2) ) = *(accumulator+t+2); + *( (float2*) (output_smem + idx + acumm2 + 16) ) = *(accumulator+t+3); // float 4, t+1 + + + *( (float2*) (output_smem + idx2 + acumm1) ) = *(accumulator+t+32); + *( (float2*) (output_smem + idx2 + acumm1 + 16) ) = *(accumulator+t+33); // float 4, t+16 + *( (float2*) (output_smem + idx2 + acumm2) ) = *(accumulator+t+34); + *( (float2*) (output_smem + idx2 + acumm2 + 16) ) = *(accumulator+t+35); // float 4, t+17 + + // the above 8 float2 will be consumed by theadIdx.y = [0,1,2,3] + + // the following 8 float2 will be consumed by theadIdx.y = [4,5,6,7] + + *( (float2*) (output_smem + idx + acumm4 + acumm1) ) = *(accumulator+t+4); + *( (float2*) (output_smem + idx + acumm4 + acumm1 + 16) ) = *(accumulator+t+5); // float 4, t+2 + *( (float2*) (output_smem + idx + acumm4 + acumm2) ) = *(accumulator+t+6); + *( (float2*) (output_smem + idx + acumm4 + acumm2 + 16) ) = *(accumulator+t+7); // float 4, t+3 + + *( (float2*) (output_smem + idx2 + acumm4 + acumm1) ) = *(accumulator+t+36); + *( (float2*) (output_smem + idx2 + acumm4 + acumm1 + 16) ) = *(accumulator+t+37); // float 4, t+18 + *( (float2*) (output_smem + idx2 + acumm4 + acumm2) ) = *(accumulator+t+38); + *( (float2*) (output_smem + idx2 + acumm4 + acumm2 + 16) ) = *(accumulator+t+39); // float 4, t+19 + + + + t+=8; + + __syncthreads(); + + + + + for(int i=0; i<16; i++){ + C_tile[i].x = shared_mem[i*offset + init]; + C_tile[i].y = shared_mem[i*offset + init + 32]; + + } + + + // transform output tiles + transform_output_tile(C, C_tile, At, round, c_tensor, c_glb_offset, id1, id2, mask1, mask2, out_w); + __syncthreads(); + } +} + +template + static __device__ T 
__forceinline__ fx_const (float val) { + return static_cast(val); + } +template <> + __device__ half __forceinline__ fx_const(float val) { + return __float2half(val); + } + +// Set of functions per row in Gw product +template + __device__ T f_row1(const T * __restrict__ G, int j){ + return G[j]; + } +template + __device__ T f_row2(const T * __restrict__ G, int j){ + return fx_const(0.5f)*(G[j] + G[6+j] + G[3+j]); + } +template + __device__ T f_row3(const T * __restrict__ G, int j){ + return fx_const(0.5f)*(G[j] + G[6+j] - G[3+j]); + } +template + __device__ T f_row4(const T * __restrict__ G, int j){ + return G[6+j]; + } + + // Set of functions per column in GwGt product +template + __device__ T f_col1(const T * __restrict__ G, int j){ + return G[j]; + } +template + __device__ T f_col2(const T * __restrict__ G, int j){ + return fx_const(0.5f)*(G[j] + G[j+2] + G[j+1]); + } +template + __device__ T f_col3(const T * __restrict__ G, int j){ + return fx_const(0.5f)*(G[j] + G[j+2] - G[j+1]); + } +template + __device__ T f_col4(const T * __restrict__ G, int j){ + return G[j+2]; + } + + template + static __device__ __forceinline__ float t2f32(T val) { + return (float) val; + } + + template <> + __device__ float __forceinline__ t2f32(half val) { + return __half2float(val); + } + + + + template + __global__ void FX(const T * __restrict__ pInputs, T * __restrict__ pOutputs, int filt_k, + int filt_c, int filt_h, int filt_w){ + + typedef T(*pointFunction_t)(const T *, int); + + // assumes KCHW layout + int Inx = threadIdx.x, Iny = threadIdx.y; + int TileX = blockIdx.x, TileY = blockIdx.y; + + // int c_glb_offset = filt_k*filt_h*filt_w; + // int c_kernel = TileY*BC*c_glb_offset + TileX*BK + Iny*c_glb_offset + Inx; + int c_glb_offset = filt_h*filt_w; + // int c_kernel = TileY*BC*c_glb_offset + TileX*BK*filt_c*c_glb_offset + Iny*c_glb_offset+ Inx*filt_c*c_glb_offset; + int c_kernel = (TileY*BC + (TileX*BK+Inx)*filt_c + Iny)*c_glb_offset; + int c_glb_offset_s = filt_k*4*4; + int c_kernel_s = TileY*BC*c_glb_offset_s + TileX*BK + Iny*c_glb_offset_s + Inx; + + T Gw[21]; //9+12. In registers + T *Gw_buffer = Gw+9; + + pointFunction_t func1[4] = {f_row1, f_row2, f_row3, f_row4}; + pointFunction_t func2[4] = {f_col1, f_col2, f_col3, f_col4}; + + for(int bk=0; bk= BC) return; + + // each thread in row 0 puts its first element of 1st filter tile(loaded by the thread) in smem + // taking 32 slots + // then puts its first element of 2nd filter tile immediately after, taking another 32 slots + // then followed by threads in row 1, 2.. until 7 + + // Note the next element is BK*BC (8*64) slots away, then another BK*BC .... + // for every 64 values, the first 32 belongs to filter tile 1, the next 32 for filter tile 2 + + for(int k=0; k<2; k++){ // prefetch 2 filter tiles/thread + for(int i=0; i<4; i++){ + #pragma unroll + for(int j=0; j<4; j++){ + pOutputs[c_tensor_s + i*c_offset_s*4 + j*c_offset_s] = tiles[k*16 + i*4 + j]; + } + } + // 2nd tile right behind the 1st? 
+ c_tensor_s += BN; // BN has nothing to do with input tiles + } + +} + +template +__device__ __forceinline__ void prefetch_filter_tile(const T * __restrict__ pInputs, float * __restrict__ tiles, int filt_k){ + + int c_tensor = blockIdx.z*BK + threadIdx.y*(filt_k<<4) + threadIdx.x; // Iny*filt_k*4*4 + // each threadIdx.y corresponds to one channel; there are 8 different threadIdx.y so 8 channels + + //each thread (32 threads in x direction) loads 2 kernel tiles (32 in K direction apart) + // save the two tiles in a float[32] register, float[16] for each + + int acumm; + #pragma unroll + for(int i=0; i<4; i++){ + acumm = (i*filt_k<<2); + #pragma unroll + for(int j=0; j<4; j++){ + tiles[(i<<2) + j] = t2f32(pInputs[acumm + j*filt_k + c_tensor]); + tiles[16 + (i<<2) + j] = t2f32(pInputs[acumm + j*filt_k + c_tensor+BN]); + } + } +} + +__device__ __forceinline__ void prefetch_input_tile(const float * __restrict__ pInputs, float *tile, int in_h, + int in_w, int tw, int th, unsigned short mask){ + + // each thread loads two input tiles to fill a half2 buffer + int c_offset = in_h*in_w; + int c_tile = blockIdx.x * TW + blockIdx.y * in_w * TH; + int c_tensor = c_tile + ((threadIdx.x & (tw-1)) << 1) + (threadIdx.x / tw ) * (in_w << 1) + + threadIdx.y*c_offset - (in_w+1); + + int acumm,x; + + + if(mask==0xFFFF){ + #pragma unroll + for(int i=0; i<4; i++){ + acumm = i*in_w; + #pragma unroll + for(int j=0; j<4; j++){ + tile[(i<<2) + j] = pInputs[acumm + j + c_tensor]; + } + } + + } else { + for(int i=0; i<4; i++){ + acumm = i*in_w; + #pragma unroll + for(int j=0; j<4; j++){ + x = (i<<2) + j; + tile[x] = 0.f; + if(mask&(1< +__global__ void Winograd_kernel(const float *A, const T *B, float *C, + int tiles_dim_w, int tiles_dim_h, + int in_c, int in_h, int in_w, + int tile_size, int X, int Y, + int filt_k, int filt_c, + int out_c, + int tile_2d_s, int out_h, int out_w){ + + extern __shared__ float shared_mem[]; + float *input_smem = (float*)shared_mem; + float *filter_smem = (float*)&shared_mem[16*BC*BN]; + + unsigned int m = 0xFFFF; + + if(blockIdx.y==0 && (threadIdx.x / X) == 0) m &= 0xFFF0; // pad top row + if(tiles_dim_w & (X-1) == 0 && (tiles_dim_h & (Y-1)) == 0){ + if(blockIdx.y==gridDim.y-1 && threadIdx.x / X == Y-1) m &= (!(in_h&1))?(0x0FFF):(0x00FF); //pad bottom row or bottom 2 rows + if(blockIdx.x==gridDim.x-1 && (threadIdx.x & (X-1)) == X-1) m &= (!(in_w&1))?(0x7777):(0x3333); // pad right col or right 2 cols + }else if((tiles_dim_w & (X-1)) == 0){ + int k = in_h & (TH-1); + int k1 = k & 1 ? (k+1)>>1 : (k>>1); // there could be 4*k1 tiles + if(blockIdx.x==gridDim.x-1 && (threadIdx.x & (X-1)) == X-1) m &= (!(in_w&1))?(0x7777):(0x3333); // pad right col or right 2 cols + if(blockIdx.y==gridDim.y-1 && threadIdx.x / X == k1-1) m &= (!(k&1))?(0x0FFF):(0x00FF); //pad bottom row or bottom 2 rows + if(blockIdx.y==gridDim.y-1 && threadIdx.x / X > k1-1) m &= 0x0; //pad all zeros since this tile does not exist + }else if((tiles_dim_h & (Y-1)) == 0){ + int k = in_w & (TW-1); + int k1 = k & 1 ? (k+1)>>1 : k>>1; // there could be 8*k1 tiles + if(blockIdx.y==gridDim.y-1 && threadIdx.x / X == Y-1) m &= (!(in_h&1))?(0x0FFF):(0x00FF); //pad bottom row or bottom 2 rows + if(blockIdx.x==gridDim.x-1 && (threadIdx.x & (X-1)) == k1-1) m &= (!(k&1))?(0x7777):(0x3333); // pad right col or right 2 cols + if(blockIdx.x==gridDim.x-1 && (threadIdx.x & (X-1)) > k1-1) m &= 0x0; //pad all zeros since this tile does not exist + }else{ + int kh = in_h & (TH-1); + int kw = in_w & (TW-1); + int kh1 = kh & 1 ? 
(kh+1)>>1 : kh>>1; // there could be kh1*kw1 tiles + int kw1 = kw & 1 ? (kw+1)>>1 : kw>>1; + if(blockIdx.y==gridDim.y-1 && threadIdx.x / X == kh1-1) m &= (!(kh&1))?(0x0FFF):(0x00FF); //pad bottom row or bottom 2 rows + if(blockIdx.y==gridDim.y-1 && threadIdx.x / X > kh1-1) m &= 0x0; //pad all zeros since this tile does not exist + if(blockIdx.x==gridDim.x-1 && (threadIdx.x & (X-1)) == kw1-1) m &= (!(kw&1))?(0x7777):(0x3333); // pad right col or right 2 cols + if(blockIdx.x==gridDim.x-1 && (threadIdx.x & (X-1)) > kw1-1) m &= 0x0; //pad all zeros since this tile does not exist + } + if(blockIdx.x==0 && (threadIdx.x & (X-1)) == 0) m &=0xeeee; // pad left col + + float img_tile[16]; // Prefetch input from GMEM + float filter_tile[32]; // Prefetch filter from GMEM + + float4 input_frag_mem[8]; //2*2(2*8/4) Data to do Outer Product + prefetch f. SMEM (double_buffer) + float4 filter_frag_mem[8]; //2*2 Data to do Outer Product + prefetch f. SMEM (double_buffer) + float4 accumulator[2][16] = {0.0f}; // Accumulators + + float4 *A_frag; // Input data pointer + int frag_offset = 2 * (BN*BC); // (2=8/4) SMEM input read offset + + float4 *B_frag; // Filter data pointer + int f_frag_offset = 2 * (BC*BK); // (2=8/4 with 4 being float4) SMEM filter read offset + + + float4 *input_frag = (float4*) input_frag_mem; + float4 *filter_frag = (float4*) filter_frag_mem; + + float4 *swap_filter; + float4 *swap_input; + + prefetch_input_tile(A, img_tile, in_h, in_w, X, Y, m); + prefetch_filter_tile(B, filter_tile, filt_k); + + float4 *input_frag_buffer = (float4*) (input_frag+4); + float4 *filter_frag_buffer = (float4*) (filter_frag+4); + + // Mainloop - iterates over the entire K dimension - not unrolled + for(int iter=0; iter(input_frag, filter_frag, accumulator); + + swap_input = input_frag; + input_frag = input_frag_buffer; + input_frag_buffer = swap_input; + + swap_filter = filter_frag; + filter_frag = filter_frag_buffer; + filter_frag_buffer = swap_filter; + + } + + A += BC*in_w*in_h; + B += filt_k*BC*4*4; + + if(iter<(in_c-BC)){ + prefetch_input_tile(A, img_tile, in_h, in_w, X, Y, m); + prefetch_filter_tile(B, filter_tile, filt_k); + } + + __syncthreads(); + } + + // Transpose, transform and store accumulated result + store_output_tile(accumulator, shared_mem, C, out_h, out_w, tiles_dim_w, tiles_dim_h, X, Y, + input_frag_mem, filter_frag_mem); + +} + + +// } + +template +static void conv_winograd_stage0_cuda( + const int src0_ne0, const int src0_ne1, const int src0_ne2, const int src0_ne3, + const int dst_ne0, const int dst_ne1, const int dst_ne2, const int dst_ne3, + const T * src0, T * dst, + cudaStream_t stream) { + // printf("doing FX\n"); + FX<<>>(src0, dst, src0_ne3, src0_ne2, src0_ne1, src0_ne0); + +} + +static void conv_winograd_stage1_f16_f32_cuda(int tiles_dim_w, int tiles_dim_h, int X, int Y, + int tile_size, int tile_2d_s, + const int src0_ne0, const int src0_ne1, const int src0_ne2, const int src0_ne3, + const int src1_ne0, const int src1_ne1, const int src1_ne2, const int src1_ne3, + const int dst_ne0, const int dst_ne1, const int dst_ne2, const int dst_ne3, + const half * src0, const float * src1, float * dst, + cudaStream_t stream) { + + int64_t filt_k = src0_ne0; + int64_t in_c = src1_ne2; + int64_t in_h = src1_ne1; + int64_t in_w = src1_ne0; + int64_t filt_c = src0_ne3; + int64_t out_c = filt_k; + int64_t out_h = in_h; + int64_t out_w = in_w; + int smem_size = (16*BN*BC + 16*BC*BK)*4; + int max_size = 65536; // 64 KB + cudaFuncSetAttribute(Winograd_kernel, 
+
+  Winograd_kernel<<>>(src1, src0, dst,
+      tiles_dim_w, tiles_dim_h, in_c, in_h, in_w, tile_size, X, Y,
+      filt_k, filt_c, out_c, tile_2d_s, out_h, out_w);
+}
+
+static void conv_winograd_stage1_f32_f32_cuda(int tiles_dim_w, int tiles_dim_h, int X, int Y,
+    int tile_size, int tile_2d_s,
+    const int src0_ne0, const int src0_ne1, const int src0_ne2, const int src0_ne3,
+    const int src1_ne0, const int src1_ne1, const int src1_ne2, const int src1_ne3,
+    const int dst_ne0,  const int dst_ne1,  const int dst_ne2,  const int dst_ne3,
+    const float * src0, const float * src1, float * dst,
+    cudaStream_t stream) {
+
+  int64_t filt_k = src0_ne0;
+  int64_t in_c   = src1_ne2;
+  int64_t in_h   = src1_ne1;
+  int64_t in_w   = src1_ne0;
+  int64_t filt_c = src0_ne3;
+  int64_t out_c  = filt_k;
+  int64_t out_h  = in_h;
+  int64_t out_w  = in_w;
+  int smem_size = (16*BN*BC + 16*BC*BK)*4;
+  int max_size  = 65536; // 64 KB
+  cudaFuncSetAttribute(Winograd_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, max_size);
+
+  Winograd_kernel<<>>(src1, src0, dst,
+      tiles_dim_w, tiles_dim_h, in_c, in_h, in_w, tile_size, X, Y,
+      filt_k, filt_c, out_c, tile_2d_s, out_h, out_w);
+}
+
+
+void ggml_cuda_op_winograd_stage0(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+  const ggml_tensor * src0 = dst->src[0];
+  // const half * src0_d = (const half *)src0->data;
+  if(src0 == NULL){
+    return;
+  }
+  // float * dst_d = (float *)dst->data;
+  cudaStream_t stream = ctx.stream();
+  // int id = ggml_cuda_get_device();
+
+  // GGML_ASSERT(src0->type == GGML_TYPE_F16);
+  // GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+  // ggml_cuda_pool_alloc src0_ddq_as_f32(ctx.pool(id));
+  // if (src0->type != GGML_TYPE_F32) {
+  //   const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);
+  //   GGML_ASSERT(to_fp32_cuda != nullptr);
+  //   int64_t nle = ggml_nelements(src0);
+  //   src0_ddq_as_f32.alloc(nle);
+  //   const char * src0_dd = (char *)src0->data;
+  //   to_fp32_cuda(src0_dd, src0_ddq_as_f32.get(), nle, stream);
+  // }
+
+  // // GGML_ASSERT(ggml_is_contiguous(src0));
+  // const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *)src0->data : src0_ddq_as_f32.get();
+  if(src0->type == GGML_TYPE_F32){
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    // conv_winograd_stage0_f32_f32_cuda(src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+    conv_winograd_stage0_cuda(src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+        dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
+        src0_d, dst_d, stream);
+  }else{
+    const half * src0_d = (const half *)src0->data;
+    half * dst_d = (half *)dst->data;
+    // conv_winograd_stage0_f16_f32_cuda(src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+    conv_winograd_stage0_cuda(src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+        dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
+        src0_d, dst_d, stream);
+  }
+}
+
+
+
+void ggml_cuda_op_winograd_stage1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+  const ggml_tensor * src0 = dst->src[0];
+  // const float * src0_d = (const float *)src0->data;
+
+  const ggml_tensor * src1 = dst->src[1];
+  const float * src1_d = (const float *)src1->data;
+
+  float * dst_d = (float *)dst->data;
+  cudaStream_t stream = ctx.stream();
+
+  // GGML_ASSERT(src0->type == GGML_TYPE_F32);
+  GGML_ASSERT(src1->type == GGML_TYPE_F32);
+  GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+  GGML_ASSERT(ggml_is_contiguous(src0));
+  GGML_ASSERT(ggml_is_contiguous(src1));
+
+  const int m = 2;
+  const int r = 3;
+  const int tile_size = m+r-1;
+  int tiles_dim_w, tiles_dim_h;
+
+  tiles_dim_w = ceil(ceil((double)(src1->ne[0]+2)/2)-1);
+  tiles_dim_h = ceil(ceil((double)(src1->ne[1]+2)/2)-1);
+
+  int tile_2d_s = tile_size*tile_size;
+
+  cudaMemcpyToSymbol(access_f_s, aux, 64*sizeof(int));
+  cudaMemcpyToSymbol(access_s, aux2, 64*sizeof(int));
+  cudaMemcpyToSymbol(tileid, tid, 64*sizeof(int));
+
+  if(src0->type == GGML_TYPE_F32){
+    const float * src0_d = (const float *)src0->data;
+    // const float * src1_d = (const float *)src1->data;
+    conv_winograd_stage1_f32_f32_cuda(tiles_dim_w, tiles_dim_h, 16, 2,
+        tile_size, tile_2d_s,
+        src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+        src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
+        dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
+        src0_d, src1_d, dst_d, stream);
+  } else{
+    const half * src0_d = (const half *)src0->data;
+    // const half * src1_d = (const half *)src1->data;
+    conv_winograd_stage1_f16_f32_cuda(tiles_dim_w, tiles_dim_h, 16, 2,
+        tile_size, tile_2d_s,
+        src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+        src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
+        dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
+        src0_d, src1_d, dst_d, stream);
+  }
+
+
+}
+
diff --git a/src/ggml-cuda/conv-winograd.cuh b/src/ggml-cuda/conv-winograd.cuh
index ebc3f48fa..be9230c02 100644
--- a/src/ggml-cuda/conv-winograd.cuh
+++ b/src/ggml-cuda/conv-winograd.cuh
@@ -1,44 +1,44 @@
-#include "common.cuh"
-
-
-#define BC 8
-#define BN 32
-#define BK 64
-#define TW 32
-#define TH 4
-#define BN_p 138
-
-__constant__ int access_f_s[2][32];
-__constant__ int access_s[2][32];
-__constant__ int tileid[2][32];
-
-
-// access_f_s
-const int aux[2][32] = {
-    {0,0,1,1,2,2,3,3,4,4,5,5,6,6,
-    7,7,0,0,1,1,2,2,3,3,4,4,5,5,6,6,7,7},
-    {8,8,9,9,10,10,11,11,12,12,13,13,
-    14,14,15,15,8,8,9,9,10,10,11,11,12,12,
-    13,13,14,14,15,15}
-    };
-// access_s
-const int aux2[2][32] = {
-    {0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,2,
-    3,2,3,2,3,2,3,2,3,2,3,2,3,2,3},
-    {4,5,4,5,4,5,4,5,4,5,4,
-    5,4,5,4,5,6,7,6,7,6,7,6,7,
-    6,7,6,7,6,7,6,7}
-    };
-
-const int tid[2][32] = {
-    {0,1,4,5,8,9,12,13,16,17,20,21,24,25,28,29,
-    0,1,4,5,8,9,12,13,16,17,20,21,24,25,28,29},
-    {2,3,6,7,10,11,14,15,18,19,22,23,26,27,30,31,
-    2,3,6,7,10,11,14,15,18,19,22,23,26,27,30,31}
-    };
-
-
-
-void ggml_cuda_op_winograd_stage0(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
-void ggml_cuda_op_winograd_stage1(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
-
+#include "common.cuh"
+#include 
+
+#define BC 8
+#define BN 32
+#define BK 64
+#define TW 32
+#define TH 4
+#define BN_p 138
+
+__constant__ int access_f_s[2][32];
+__constant__ int access_s[2][32];
+__constant__ int tileid[2][32];
+
+
+// access_f_s
+const int aux[2][32] = {
+    {0,0,1,1,2,2,3,3,4,4,5,5,6,6,
+    7,7,0,0,1,1,2,2,3,3,4,4,5,5,6,6,7,7},
+    {8,8,9,9,10,10,11,11,12,12,13,13,
+    14,14,15,15,8,8,9,9,10,10,11,11,12,12,
+    13,13,14,14,15,15}
+    };
+// access_s
+const int aux2[2][32] = {
+    {0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,2,
+    3,2,3,2,3,2,3,2,3,2,3,2,3,2,3},
+    {4,5,4,5,4,5,4,5,4,5,4,
+    5,4,5,4,5,6,7,6,7,6,7,6,7,
+    6,7,6,7,6,7,6,7}
+    };
+
+const int tid[2][32] = {
+    {0,1,4,5,8,9,12,13,16,17,20,21,24,25,28,29,
+    0,1,4,5,8,9,12,13,16,17,20,21,24,25,28,29},
+    {2,3,6,7,10,11,14,15,18,19,22,23,26,27,30,31,
+    2,3,6,7,10,11,14,15,18,19,22,23,26,27,30,31}
+    };
+
+
+
+void ggml_cuda_op_winograd_stage0(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+void ggml_cuda_op_winograd_stage1(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
diff --git a/tests/test-conv2d-winograd.cpp b/tests/test-conv2d-winograd.cpp
index 0c08c6817..ae5a4111d 100644
--- a/tests/test-conv2d-winograd.cpp
+++ b/tests/test-conv2d-winograd.cpp
@@ -37,8 +37,8 @@ struct test_model {
 void load_model(test_model & model, bool use_gpu = false) {
     // create data
-    int KW = 3, KH = 3, IC = 256, OC = 256;
-    int IW = 832, IH = 1216, N = 1;
+    int KW = 3, KH = 3, IC = 512, OC = 512;
+    int IW = 64, IH = 96, N = 1;
 
     printf(" input: IC = %d, OC = %d, IW = %d, IH = %d \n ", IC, OC, IW, IH);
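
The launch arithmetic in the stage-1 wrappers can be sanity-checked on the host. The sketch below is not part of the patch: the main() wrapper is illustrative only, the BC/BN/BK values are copied assumptions taken from conv-winograd.cuh above, and the input dimensions are the ones set in tests/test-conv2d-winograd.cpp. It reproduces the tiles_dim_w/tiles_dim_h formula from ggml_cuda_op_winograd_stage1 and the smem_size formula from conv_winograd_stage1_*_cuda.

// Standalone C++ sketch (hypothetical helper, not repo code): checks the
// tile-count and dynamic shared-memory arithmetic used by the launchers above.
#include <cmath>
#include <cstdio>

#define BC 8    // copied from conv-winograd.cuh
#define BN 32   // copied from conv-winograd.cuh
#define BK 64   // copied from conv-winograd.cuh

int main() {
    const int m = 2, r = 3;            // F(2x2, 3x3) Winograd
    const int tile_size = m + r - 1;   // 4x4 input tiles
    const int in_w = 64, in_h = 96;    // dims from tests/test-conv2d-winograd.cpp

    // Same formula as ggml_cuda_op_winograd_stage1: number of 2x2 output tiles per axis.
    int tiles_dim_w = (int) ceil(ceil((double)(in_w + 2) / 2) - 1);
    int tiles_dim_h = (int) ceil(ceil((double)(in_h + 2) / 2) - 1);

    // Same formula as the stage-1 launchers: 16 floats per transformed tile,
    // BN input tiles and BK filter slots per block, BC channels deep, 4 bytes each.
    int smem_size = (16 * BN * BC + 16 * BC * BK) * 4;

    printf("tiles: %d x %d (input tile size %d)\n", tiles_dim_w, tiles_dim_h, tile_size);
    printf("dynamic smem per block: %d bytes\n", smem_size);
    return 0;
}

For IW = 64, IH = 96 this prints 32 x 48 tiles and 49152 bytes of dynamic shared memory per block; 49152 bytes is exactly the conventional 48 KB per-block default, which is presumably why the launchers opt in to a 65536-byte limit through cudaFuncAttributeMaxDynamicSharedMemorySize before launching Winograd_kernel.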