From 5b4e44828890bf14dd640e7abc9543bf0a6f5b0d Mon Sep 17 00:00:00 2001 From: bssrdf Date: Thu, 26 Sep 2024 18:04:04 -0400 Subject: [PATCH 01/23] added two winograd ops --- include/ggml.h | 2 ++ src/ggml.c | 70 +++++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 71 insertions(+), 1 deletion(-) diff --git a/include/ggml.h b/include/ggml.h index f7e5cfc38..ce8368947 100644 --- a/include/ggml.h +++ b/include/ggml.h @@ -510,6 +510,8 @@ extern "C" { GGML_OP_TIMESTEP_EMBEDDING, GGML_OP_ARGSORT, GGML_OP_LEAKY_RELU, + GGML_OP_WINOGRAD_STAGE0, + GGML_OP_WINOGRAD_STAGE1, GGML_OP_FLASH_ATTN_EXT, GGML_OP_FLASH_ATTN_BACK, diff --git a/src/ggml.c b/src/ggml.c index 4b782b0c1..37d99868f 100644 --- a/src/ggml.c +++ b/src/ggml.c @@ -2995,6 +2995,8 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "TIMESTEP_EMBEDDING", "ARGSORT", "LEAKY_RELU", + "WINOGRAD_STAGE0", + "WINOGRAD_STAGE1", "FLASH_ATTN_EXT", "FLASH_ATTN_BACK", @@ -3089,6 +3091,8 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "timestep_embedding(timesteps, dim, max_period)", "argsort(x)", "leaky_relu(x)", + "winograd_stage0(x)", + "winograd_stage1(x)", "flash_attn_ext(x)", "flash_attn_back(x)", @@ -3118,7 +3122,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "adamw(x)", }; -static_assert(GGML_OP_COUNT == 80, "GGML_OP_COUNT != 80"); +static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -7166,6 +7170,70 @@ struct ggml_tensor * ggml_conv_transpose_2d_p0( return result; } + +// ggml_winograd + +// a: [OC,IC, 3, 3] +// result: [OC, IC, 16] +struct ggml_tensor * ggml_winograd_stage0( + struct ggml_context * ctx, + struct ggml_tensor * a) { + bool is_node = false; + GGML_ASSERT(a->ne[0] == 3 && a->ne[1] == 3); // kernel should be 3x3 + if (a->grad) { + is_node = true; + } + + struct ggml_tensor * result = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 16, a->ne[2], a->ne[3], 1); + + result->op = GGML_OP_WINOGRAD_STAGE0; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + +// ggml_winograd +// a: [OC, IC, 4, 4] +// b: [1, IC, IH, IW] +// result: [N, OC, OH, OW] +struct ggml_tensor * ggml_winograd_stage1( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + bool is_node = false; + if (a->grad) { + is_node = true; + } + + int OW = b->ne[0]; + int OH = b->ne[1]; + struct ggml_tensor * result = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, OW, OH, a->ne[3] /* OC */, 1); + + result->op = GGML_OP_WINOGRAD_STAGE1; + result->grad = is_node ? 
ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +struct ggml_tensor * ggml_conv_2d_3x3( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b){ + + + GGML_ASSERT(b->ne[3] == 1); // only works for 1 input image + + struct ggml_tensor* W = ggml_winograd_stage0(ctx, a); + struct ggml_tensor * result = ggml_winograd_stage1(ctx, W, b); + + return result; + +} + + // ggml_pool_* static int64_t ggml_calc_pool_output_size(int64_t ins, int ks, int s, float p) { From 68c251bf4387880da2d96d2149aa9a169cd5374f Mon Sep 17 00:00:00 2001 From: bssrdf Date: Thu, 26 Sep 2024 22:02:42 -0400 Subject: [PATCH 02/23] added more checking for precondition to use winograd; switch to im2co if not satisfied --- src/ggml.c | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/ggml.c b/src/ggml.c index 37d99868f..65d8c79a0 100644 --- a/src/ggml.c +++ b/src/ggml.c @@ -7223,8 +7223,11 @@ struct ggml_tensor * ggml_conv_2d_3x3( struct ggml_tensor * a, struct ggml_tensor * b){ - GGML_ASSERT(b->ne[3] == 1); // only works for 1 input image + GGML_ASSERT(b->ne[2] == a->ne[2]); // number of channels must match + if(a->ne[3] % 64 != 0 || a->ne[2] % 8 != 0) // only works for the number of filters is a multiple of 64 + return ggml_conv_2d(ctx, a, b, 1, 1, 1, 1, 1, 1); // and the number of channels is a multiple of 8 + struct ggml_tensor* W = ggml_winograd_stage0(ctx, a); struct ggml_tensor * result = ggml_winograd_stage1(ctx, W, b); From 2ccc67da6dddff91c006b97787c5c557baa9a650 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Fri, 27 Sep 2024 08:58:15 -0400 Subject: [PATCH 03/23] added source code for winograd kernel --- src/ggml-cuda/conv-winograd.cu | 124 ++++++ src/ggml-cuda/conv-winograd.cuh | 767 ++++++++++++++++++++++++++++++++ 2 files changed, 891 insertions(+) create mode 100644 src/ggml-cuda/conv-winograd.cu create mode 100644 src/ggml-cuda/conv-winograd.cuh diff --git a/src/ggml-cuda/conv-winograd.cu b/src/ggml-cuda/conv-winograd.cu new file mode 100644 index 000000000..2ff57f822 --- /dev/null +++ b/src/ggml-cuda/conv-winograd.cu @@ -0,0 +1,124 @@ +#include "conv-transpose-1d.cuh" + +static __global__ void conv_transpose_1d_kernel( + const int s0, const int p0, const int d0, const int output_size, + const int src0_ne0, const int src0_ne1, const int src0_ne2, const int src0_ne3, + const int src1_ne0, const int src1_ne1, const int src1_ne2, const int src1_ne3, + const int dst_ne0, const int dst_ne1, const int dst_ne2, const int dst_ne3, + const float * src0, const float * src1, float * dst) { + int global_index = threadIdx.x + blockIdx.x * blockDim.x; + if (global_index >= output_size) { + return; + } + + int out_index = global_index / dst_ne0; + + float accumulator = 0; + + for (int c = 0; c < src0_ne2; c++) { + int idx = global_index % dst_ne0; + + int kernel_offset = (src0_ne0 * src0_ne1 * c) + (out_index * src0_ne0); + int input_offset = src1_ne0 * c; + + for (int i = 0; i < src1_ne0; i++) { + if (!(idx >= i*s0 && idx < i*s0 + src0_ne0)) { + continue; + } + int weight_idx = idx - i*s0; + + float kernel_weight = src0[kernel_offset + weight_idx]; + float input_value = src1[input_offset+i]; + + accumulator += kernel_weight * input_value; + } + } + dst[global_index] = accumulator; +} + +static void conv_transpose_1d_f32_f32_cuda( + const int s0, const int p0, const int d0, const int output_size, + const int src0_ne0, const int src0_ne1, const int src0_ne2, const int src0_ne3, + const int src1_ne0, const int src1_ne1, const int 
src1_ne2, const int src1_ne3, + const int dst_ne0, const int dst_ne1, const int dst_ne2, const int dst_ne3, + const float * src0, const float * src1, float * dst, + cudaStream_t stream) { + + const int num_blocks = (output_size + CUDA_CONV_TRANPOSE_1D_BLOCK_SIZE - 1) / CUDA_CONV_TRANPOSE_1D_BLOCK_SIZE; + conv_transpose_1d_kernel<<>>( + s0,p0,d0,output_size, + src0_ne0, src0_ne1, src0_ne2, src0_ne3, + src1_ne0, src1_ne1, src1_ne2, src1_ne3, + dst_ne0, dst_ne1, dst_ne2, dst_ne3, + src0,src1, dst); +} + + +void ggml_cuda_op_winograd_stage0(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const float * src0_d = (const float *)src0->data; + + const ggml_tensor * src1 = dst->src[1]; + const float * src1_d = (const float *)src1->data; + + float * dst_d = (float *)dst->data; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src1)); + + const int32_t * opts = (const int32_t *)dst->op_params; + + const int s0 = opts[0]; + const int p0 = 0;//opts[3]; + const int d0 = 1;//opts[4]; + + const int64_t kernel_size = ggml_nelements(src0); + const int64_t input_size = ggml_nelements(src1); + const int64_t output_size = ggml_nelements(dst); + + conv_transpose_1d_f32_f32_cuda(s0, p0, d0, output_size, + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], + src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + src0_d, src1_d, dst_d, stream); +} + + +void ggml_cuda_op_winograd_stage1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const float * src0_d = (const float *)src0->data; + + const ggml_tensor * src1 = dst->src[1]; + const float * src1_d = (const float *)src1->data; + + float * dst_d = (float *)dst->data; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src1)); + + const int32_t * opts = (const int32_t *)dst->op_params; + + const int s0 = opts[0]; + const int p0 = 0;//opts[3]; + const int d0 = 1;//opts[4]; + + const int64_t kernel_size = ggml_nelements(src0); + const int64_t input_size = ggml_nelements(src1); + const int64_t output_size = ggml_nelements(dst); + + conv_transpose_1d_f32_f32_cuda(s0, p0, d0, output_size, + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], + src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + src0_d, src1_d, dst_d, stream); +} + + diff --git a/src/ggml-cuda/conv-winograd.cuh b/src/ggml-cuda/conv-winograd.cuh new file mode 100644 index 000000000..8c3c388a5 --- /dev/null +++ b/src/ggml-cuda/conv-winograd.cuh @@ -0,0 +1,767 @@ +#include "common.cuh" + +// #define CUDA_CONV_TRANPOSE_1D_BLOCK_SIZE 256 +#define BC 8 +#define BN 32 +#define BK 64 +#define TW 8 +#define TH 16 +#define BN_p 138 + +__constant__ int access_f_s[2][32]; +__constant__ int access_s[2][32]; +__constant__ int tileid[2][32]; + + +// access_f_s +const int aux[2][32] = { + {0,0,1,1,2,2,3,3,4,4,5,5,6,6, + 7,7,0,0,1,1,2,2,3,3,4,4,5,5,6,6,7,7}, + {8,8,9,9,10,10,11,11,12,12,13,13, + 14,14,15,15,8,8,9,9,10,10,11,11,12,12, + 13,13,14,14,15,15} + }; +// access_s +const int aux2[2][32] = { + {0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,2, + 3,2,3,2,3,2,3,2,3,2,3,2,3,2,3}, + {4,5,4,5,4,5,4,5,4,5,4, + 5,4,5,4,5,6,7,6,7,6,7,6,7, + 
6,7,6,7,6,7,6,7} + }; + +const int tid[2][32] = { + {0,1,4,5,8,9,12,13,16,17,20,21,24,25,28,29, + 0,1,4,5,8,9,12,13,16,17,20,21,24,25,28,29}, + {2,3,6,7,10,11,14,15,18,19,22,23,26,27,30,31, + 2,3,6,7,10,11,14,15,18,19,22,23,26,27,30,31} + }; + +__device__ void __inline__ outer_product(float4* input_frag, float4* filter_frag, float4 accumulator[][16]){ + accumulator[0][0].x += input_frag[0].x*filter_frag[0].x; + accumulator[0][0].y += input_frag[0].y*filter_frag[0].x; + accumulator[0][0].z += input_frag[0].z*filter_frag[0].x; + accumulator[0][0].w += input_frag[0].w*filter_frag[0].x; + + accumulator[0][1].x += input_frag[1].x*filter_frag[0].x; + accumulator[0][1].y += input_frag[1].y*filter_frag[0].x; + accumulator[0][1].z += input_frag[1].z*filter_frag[0].x; + accumulator[0][1].w += input_frag[1].w*filter_frag[0].x; + + accumulator[0][2].x += input_frag[0].x*filter_frag[0].y; + accumulator[0][2].y += input_frag[0].y*filter_frag[0].y; + accumulator[0][2].z += input_frag[0].z*filter_frag[0].y; + accumulator[0][2].w += input_frag[0].w*filter_frag[0].y; + + accumulator[0][3].x += input_frag[1].x*filter_frag[0].y; + accumulator[0][3].y += input_frag[1].y*filter_frag[0].y; + accumulator[0][3].z += input_frag[1].z*filter_frag[0].y; + accumulator[0][3].w += input_frag[1].w*filter_frag[0].y; + + accumulator[0][4].x += input_frag[0].x*filter_frag[0].z; + accumulator[0][4].y += input_frag[0].y*filter_frag[0].z; + accumulator[0][4].z += input_frag[0].z*filter_frag[0].z; + accumulator[0][4].w += input_frag[0].w*filter_frag[0].z; + + accumulator[0][5].x += input_frag[1].x*filter_frag[0].z; + accumulator[0][5].y += input_frag[1].y*filter_frag[0].z; + accumulator[0][5].z += input_frag[1].z*filter_frag[0].z; + accumulator[0][5].w += input_frag[1].w*filter_frag[0].z; + + accumulator[0][6].x += input_frag[0].x*filter_frag[0].w; + accumulator[0][6].y += input_frag[0].y*filter_frag[0].w; + accumulator[0][6].z += input_frag[0].z*filter_frag[0].w; + accumulator[0][6].w += input_frag[0].w*filter_frag[0].w; + + accumulator[0][7].x += input_frag[1].x*filter_frag[0].w; + accumulator[0][7].y += input_frag[1].y*filter_frag[0].w; + accumulator[0][7].z += input_frag[1].z*filter_frag[0].w; + accumulator[0][7].w += input_frag[1].w*filter_frag[0].w; + + // + accumulator[0][8].x += input_frag[0].x*filter_frag[1].x; + accumulator[0][8].y += input_frag[0].y*filter_frag[1].x; + accumulator[0][8].z += input_frag[0].z*filter_frag[1].x; + accumulator[0][8].w += input_frag[0].w*filter_frag[1].x; + + accumulator[0][9].x += input_frag[1].x*filter_frag[1].x; + accumulator[0][9].y += input_frag[1].y*filter_frag[1].x; + accumulator[0][9].z += input_frag[1].z*filter_frag[1].x; + accumulator[0][9].w += input_frag[1].w*filter_frag[1].x; + + accumulator[0][10].x += input_frag[0].x*filter_frag[1].y; + accumulator[0][10].y += input_frag[0].y*filter_frag[1].y; + accumulator[0][10].z += input_frag[0].z*filter_frag[1].y; + accumulator[0][10].w += input_frag[0].w*filter_frag[1].y; + + accumulator[0][11].x += input_frag[1].x*filter_frag[1].y; + accumulator[0][11].y += input_frag[1].y*filter_frag[1].y; + accumulator[0][11].z += input_frag[1].z*filter_frag[1].y; + accumulator[0][11].w += input_frag[1].w*filter_frag[1].y; + + accumulator[0][12].x += input_frag[0].x*filter_frag[1].z; + accumulator[0][12].y += input_frag[0].y*filter_frag[1].z; + accumulator[0][12].z += input_frag[0].z*filter_frag[1].z; + accumulator[0][12].w += input_frag[0].w*filter_frag[1].z; + + accumulator[0][13].x += input_frag[1].x*filter_frag[1].z; + accumulator[0][13].y += 
input_frag[1].y*filter_frag[1].z; + accumulator[0][13].z += input_frag[1].z*filter_frag[1].z; + accumulator[0][13].w += input_frag[1].w*filter_frag[1].z; + + accumulator[0][14].x += input_frag[0].x*filter_frag[1].w; + accumulator[0][14].y += input_frag[0].y*filter_frag[1].w; + accumulator[0][14].z += input_frag[0].z*filter_frag[1].w; + accumulator[0][14].w += input_frag[0].w*filter_frag[1].w; + + accumulator[0][15].x += input_frag[1].x*filter_frag[1].w; + accumulator[0][15].y += input_frag[1].y*filter_frag[1].w; + accumulator[0][15].z += input_frag[1].z*filter_frag[1].w; + accumulator[0][15].w += input_frag[1].w*filter_frag[1].w; + + ////// + accumulator[1][0].x += input_frag[2].x*filter_frag[2].x; + accumulator[1][0].y += input_frag[2].y*filter_frag[2].x; + accumulator[1][0].z += input_frag[2].z*filter_frag[2].x; + accumulator[1][0].w += input_frag[2].w*filter_frag[2].x; + + accumulator[1][1].x += input_frag[3].x*filter_frag[2].x; + accumulator[1][1].y += input_frag[3].y*filter_frag[2].x; + accumulator[1][1].z += input_frag[3].z*filter_frag[2].x; + accumulator[1][1].w += input_frag[3].w*filter_frag[2].x; + + accumulator[1][2].x += input_frag[2].x*filter_frag[2].y; + accumulator[1][2].y += input_frag[2].y*filter_frag[2].y; + accumulator[1][2].z += input_frag[2].z*filter_frag[2].y; + accumulator[1][2].w += input_frag[2].w*filter_frag[2].y; + + accumulator[1][3].x += input_frag[3].x*filter_frag[2].y; + accumulator[1][3].y += input_frag[3].y*filter_frag[2].y; + accumulator[1][3].z += input_frag[3].z*filter_frag[2].y; + accumulator[1][3].w += input_frag[3].w*filter_frag[2].y; + + accumulator[1][4].x += input_frag[2].x*filter_frag[2].z; + accumulator[1][4].y += input_frag[2].y*filter_frag[2].z; + accumulator[1][4].z += input_frag[2].z*filter_frag[2].z; + accumulator[1][4].w += input_frag[2].w*filter_frag[2].z; + + accumulator[1][5].x += input_frag[3].x*filter_frag[2].z; + accumulator[1][5].y += input_frag[3].y*filter_frag[2].z; + accumulator[1][5].z += input_frag[3].z*filter_frag[2].z; + accumulator[1][5].w += input_frag[3].w*filter_frag[2].z; + + accumulator[1][6].x += input_frag[2].x*filter_frag[2].w; + accumulator[1][6].y += input_frag[2].y*filter_frag[2].w; + accumulator[1][6].z += input_frag[2].z*filter_frag[2].w; + accumulator[1][6].w += input_frag[2].w*filter_frag[2].w; + + accumulator[1][7].x += input_frag[3].x*filter_frag[2].w; + accumulator[1][7].y += input_frag[3].y*filter_frag[2].w; + accumulator[1][7].z += input_frag[3].z*filter_frag[2].w; + accumulator[1][7].w += input_frag[3].w*filter_frag[2].w; + + // + accumulator[1][8].x += input_frag[2].x*filter_frag[3].x; + accumulator[1][8].y += input_frag[2].y*filter_frag[3].x; + accumulator[1][8].z += input_frag[2].z*filter_frag[3].x; + accumulator[1][8].w += input_frag[2].w*filter_frag[3].x; + + accumulator[1][9].x += input_frag[3].x*filter_frag[3].x; + accumulator[1][9].y += input_frag[3].y*filter_frag[3].x; + accumulator[1][9].z += input_frag[3].z*filter_frag[3].x; + accumulator[1][9].w += input_frag[3].w*filter_frag[3].x; + + accumulator[1][10].x += input_frag[2].x*filter_frag[3].y; + accumulator[1][10].y += input_frag[2].y*filter_frag[3].y; + accumulator[1][10].z += input_frag[2].z*filter_frag[3].y; + accumulator[1][10].w += input_frag[2].w*filter_frag[3].y; + + accumulator[1][11].x += input_frag[3].x*filter_frag[3].y; + accumulator[1][11].y += input_frag[3].y*filter_frag[3].y; + accumulator[1][11].z += input_frag[3].z*filter_frag[3].y; + accumulator[1][11].w += input_frag[3].w*filter_frag[3].y; + + accumulator[1][12].x += 
input_frag[2].x*filter_frag[3].z; + accumulator[1][12].y += input_frag[2].y*filter_frag[3].z; + accumulator[1][12].z += input_frag[2].z*filter_frag[3].z; + accumulator[1][12].w += input_frag[2].w*filter_frag[3].z; + + accumulator[1][13].x += input_frag[3].x*filter_frag[3].z; + accumulator[1][13].y += input_frag[3].y*filter_frag[3].z; + accumulator[1][13].z += input_frag[3].z*filter_frag[3].z; + accumulator[1][13].w += input_frag[3].w*filter_frag[3].z; + + accumulator[1][14].x += input_frag[2].x*filter_frag[3].w; + accumulator[1][14].y += input_frag[2].y*filter_frag[3].w; + accumulator[1][14].z += input_frag[2].z*filter_frag[3].w; + accumulator[1][14].w += input_frag[2].w*filter_frag[3].w; + + accumulator[1][15].x += input_frag[3].x*filter_frag[3].w; + accumulator[1][15].y += input_frag[3].y*filter_frag[3].w; + accumulator[1][15].z += input_frag[3].z*filter_frag[3].w; + accumulator[1][15].w += input_frag[3].w*filter_frag[3].w; + } + +extern "C" +{ + +__device__ __forceinline__ void transform_output_tile(float *pOutputs, float2 *C_tile, float2 *At, + int round, int c_tensor, int c_glb_offset, int i1, int i2, + unsigned short mask1, unsigned short mask2, int out_w) +{ + + c_tensor += (((round)/2)*32 + ((round)%2)*2)*c_glb_offset; + int x, x1; + + #pragma unroll + for(int j=0; j<4; j++){ + + At[j].x = C_tile[j].x + C_tile[4+j].x + C_tile[8+j].x; + At[j].y = C_tile[j].y + C_tile[4+j].y + C_tile[8+j].y; + + At[4+j].x = C_tile[4+j].x - C_tile[8+j].x - C_tile[12+j].x; + At[4+j].y = C_tile[4+j].y - C_tile[8+j].y - C_tile[12+j].y; + + } + + #pragma unroll + for(int i=0; i<2; i++){ + x = i*4; + x1 = i*((out_w-(out_w%2)) + (out_w%2)/2); + + if(mask1&(1<<(i*2))){ + pOutputs[x1 + c_tensor + i1] = At[x].x + At[x+1].x + At[x+2].x; + } + if(mask2&(1<<(i*2))){ + pOutputs[x1 + c_tensor + i2] = At[x].y + At[x+1].y + At[x+2].y; + } + if(mask1&(1<<(i*2+1))){ + pOutputs[x1 + c_tensor + i1 + 1] = At[x+1].x - At[x+2].x - At[x+3].x; + } + if(mask2&(1<<(i*2+1))){ + pOutputs[x1 + c_tensor + i2 + 1] = At[x+1].y - At[x+2].y - At[x+3].y; + } + } +} + +__device__ __forceinline__ unsigned short get_mask(int idd, int tiles_dim_w, int tiles_dim_h, + int tw, int th, int out_w, int out_h){ + + unsigned short mask = 0x000F; + // if((blockIdx.y/tiles_dim)==(tiles_dim-1) && out_w%2) mask&=0x0003; // pad bottom row + // if(!((blockIdx.y+1)%tiles_dim) && out_w%2) mask&=0X0005; // pad right col + // if(blockIdx.y==gridDim.y-1 && (idd / tw) == th-1 && out_h%2) mask&=0x0003; // pad bottom row + // if(blockIdx.x==gridDim.x-1 && (idd % tw) == tw-1 && out_w%2) mask&=0X0005; // pad right col + if(tiles_dim_w % tw == 0 && tiles_dim_h % th == 0){ + if(blockIdx.y==gridDim.y-1 && (idd / tw) == th-1 && out_h%2) mask&=0x0003; // pad bottom row + if(blockIdx.x==gridDim.x-1 && (idd % tw) == tw-1 && out_w%2) mask&=0X0005; // pad right col + }else if(tiles_dim_w % tw == 0){ + int k = out_h % TH; + int k1 = k % 2 ? (k+1)/2 : k/2; // there could be 4*k1 tiles + if(blockIdx.y==gridDim.y-1 && (idd / tw) == k1-1 && k%2) mask&=0x0003; // pad bottom row + if(blockIdx.y==gridDim.y-1 && (idd / tw) > k1-1) mask &= 0x0; //pad all zeros since this tile does not exist + }else if(tiles_dim_h % th == 0){ + int k = out_w % TW; + int k1 = k % 2 ? (k+1)/2 : k/2; // there could be 4*k1 tiles + if(blockIdx.x==gridDim.x-1 && (idd % tw) == k1-1 && k%2) mask&=0X0005; // pad right col + if(blockIdx.x==gridDim.x-1 && (idd % tw) > k1-1) mask&=0X0; // pad all zeroes + }else{ + int kh = out_h % TH; + int kw = out_w % TW; + int kh1 = kh % 2 ? 
(kh+1)/2 : kh/2; // there could be kh1*kw1 tiles + int kw1 = kw % 2 ? (kw+1)/2 : kw/2; + if(blockIdx.y==gridDim.y-1 && (idd / tw) == kh1-1 && kh%2) mask&=0x0003; // pad bottom row + if(blockIdx.x==gridDim.x-1 && (idd % tw) == kw1-1 && kw%2) mask&=0X0005; // pad right col + if(blockIdx.y==gridDim.y-1 && (idd / tw) > kh1-1) mask &= 0x0; //pad all zeros since this tile does not exist + if(blockIdx.x==gridDim.x-1 && (idd % tw) > kw1-1) mask &= 0X0; // pad all zeroes + } + return mask; +} + +__device__ __forceinline__ void store_output_tile(float4 acumm_smem[][16], float *shared_mem, float *C, +int out_h, int out_w, int tiles_dim_w, int tiles_dim_h, int tw, int th, +float4 *input_frag_mem, float4* filter_frag_mem){ + + float2 *output_smem = (float2 *) shared_mem; + float2 *accumulator = (float2 *) acumm_smem; + float2 *C_out = (float2*)C; + + float2 *C_tile = (float2*) input_frag_mem; + float2 *At = (float2*) filter_frag_mem; + + int idd1 = tileid[0][threadIdx.x]; + int id1 = (idd1 % tw) * 2 + (idd1 / tw) * out_w * 2; + int idd2 = tileid[1][threadIdx.x]; + int id2 = (idd2 % tw) * 2 + (idd2 / tw) * out_w * 2; + + // unsigned short mask1 = 0x000F; + unsigned short mask1 = get_mask(idd1, tiles_dim_w, tiles_dim_h, tw, th, out_w, out_h); + unsigned short mask2 = get_mask(idd2, tiles_dim_w, tiles_dim_h, tw, th, out_w, out_h); + + // output transpose step + int t=0; + int acumm1, acumm2; + // For transposing + //acumm1 = access_s_out[Inx]; //* 4 + acumm1 = ((threadIdx.x%8)/2)*34 + threadIdx.x%2 + (threadIdx.x/16)*2 + ((threadIdx.x/8)%2)*8; + acumm2 = acumm1+4; + + int acumm4 = BN_p*16 ; //*4 + int idx = threadIdx.y * BN_p; + int idx2 = idx + BN_p*8; //(BN_p*2 *8)/2 + + // For transformating + int offset = BN_p *2; //*2/2 + int init = ( (threadIdx.y/4)*BN_p*16 + (threadIdx.y%4)*(32+2) ) *2 + threadIdx.x; + + int c_glb_offset = out_h*out_w; + // int c_tensor = blockIdx.z*c_glb_offset*BK + (blockIdx.y%tiles_dim)*2 + (blockIdx.y/tiles_dim)*out_w*2 + + // blockIdx.x*BN + (threadIdx.x%16)*2+ + // ((threadIdx.x/16)*16 + (threadIdx.y%4)*4 + threadIdx.y/4)*c_glb_offset; + + int tx = TW, ty = TH; + // int c_tile = blockIdx.x * tx + blockIdx.y * in_w * ty; + // int c_tensor = c_tile + (threadIdx.x % tw) * 2 + (threadIdx.x / tw) * in_w * 2 + + // threadIdx.y*(in_h*in_w) - (in_w+1); + + int c_tensor = blockIdx.z*c_glb_offset*BK + blockIdx.x * tx + blockIdx.y * out_w * ty + + // (threadIdx.x % tw) * 2 + (threadIdx.x / tw) * out_w * 2 + + ((threadIdx.x/16)*16 + (threadIdx.y%4)*4 + threadIdx.y/4)*c_glb_offset; + + // c_tensor/=2; + + + int target = 16; + + #pragma unroll + for(int round=0; round<4; round++){ + + *( (float2*) (output_smem + idx + acumm1) ) = *(accumulator+t); + *( (float2*) (output_smem + idx + acumm1 + 16) ) = *(accumulator+t+1); // float 4, t + *( (float2*) (output_smem + idx + acumm2) ) = *(accumulator+t+2); + *( (float2*) (output_smem + idx + acumm2 + 16) ) = *(accumulator+t+3); // float 4, t+1 + + + *( (float2*) (output_smem + idx2 + acumm1) ) = *(accumulator+t+32); + *( (float2*) (output_smem + idx2 + acumm1 + 16) ) = *(accumulator+t+33); // float 4, t+16 + *( (float2*) (output_smem + idx2 + acumm2) ) = *(accumulator+t+34); + *( (float2*) (output_smem + idx2 + acumm2 + 16) ) = *(accumulator+t+35); // float 4, t+17 + + // the above 8 float2 will be consumed by theadIdx.y = [0,1,2,3] + + // the following 8 float2 will be consumed by theadIdx.y = [4,5,6,7] + + *( (float2*) (output_smem + idx + acumm4 + acumm1) ) = *(accumulator+t+4); + *( (float2*) (output_smem + idx + acumm4 + acumm1 + 16) ) = 
*(accumulator+t+5); // float 4, t+2 + *( (float2*) (output_smem + idx + acumm4 + acumm2) ) = *(accumulator+t+6); + *( (float2*) (output_smem + idx + acumm4 + acumm2 + 16) ) = *(accumulator+t+7); // float 4, t+3 + + *( (float2*) (output_smem + idx2 + acumm4 + acumm1) ) = *(accumulator+t+36); + *( (float2*) (output_smem + idx2 + acumm4 + acumm1 + 16) ) = *(accumulator+t+37); // float 4, t+18 + *( (float2*) (output_smem + idx2 + acumm4 + acumm2) ) = *(accumulator+t+38); + *( (float2*) (output_smem + idx2 + acumm4 + acumm2 + 16) ) = *(accumulator+t+39); // float 4, t+19 + + + + t+=8; + + __syncthreads(); + + + + // for output transformation, the role of threadIdx.y changes again: + // in the main loop, different threadIdx.y deal with different element of the 4x4 tile + // here, they are for 4 different groups of lane ids from optSTS64 layout + // for init (and init+32), we need to identify its tile number (0-31) within the supertile + // first, from init, find out from which threadIdx.x it comes. + + // now we got l, which is the land id which computed accumulated sum for the tile element + // each lane id (or threadIdx.x) computed 8 tiles which are distributed into 4 locations spreading + // over the smem. We need to find which of the 8 the current tile is. + // use tileid table to figure out + // int id1 = tileid[0][l]; + + + // for 2nd tile + + + for(int i=0; i<16; i++){ + C_tile[i].x = shared_mem[i*offset + init]; + C_tile[i].y = shared_mem[i*offset + init + 32]; + + } + + + // transform output tiles + transform_output_tile(C, C_tile, At, round, c_tensor, c_glb_offset, id1, id2, mask1, mask2, out_w); + __syncthreads(); + } +} + + +// Set of functions per row in Gw product +__device__ float f_row1(float *Gw, int j){ + return Gw[j]; + } + __device__ float f_row2(float *Gw, int j){ + return 0.5*(Gw[j] + Gw[6+j] + Gw[3+j]); + } + __device__ float f_row3(float *Gw, int j){ + return 0.5*(Gw[j] + Gw[6+j] - Gw[3+j]); + } + __device__ float f_row4(float *Gw, int j){ + return Gw[6+j]; + } + // Set of functions per column in GwGt product + __device__ float f_col1(float *Gw, int j){ + return Gw[j]; + } + __device__ float f_col2(float *Gw, int j){ + return 0.5*(Gw[j] + Gw[j+2] + Gw[j+1]); + } + __device__ float f_col3(float *Gw, int j){ + return 0.5*(Gw[j] + Gw[j+2] - Gw[j+1]); + } + __device__ float f_col4(float *Gw, int j){ + return Gw[j+2]; + } + + typedef float(*pointFunction_t)(float *, int); + + __global__ void FX(float *pInputs, float *pOutputs, int filt_k, + int filt_c, int filt_h, int filt_w, int alpha){ + int Inx = threadIdx.x, Iny = threadIdx.y; + int TileX = blockIdx.x, TileY = blockIdx.y; + + int c_glb_offset = filt_k*filt_h*filt_w; + int c_kernel = TileY*BC*c_glb_offset + TileX*BK + Iny*c_glb_offset + Inx; + int c_glb_offset_s = filt_k*4*4; + int c_kernel_s = TileY*BC*c_glb_offset_s + TileX*BK + Iny*c_glb_offset_s + Inx; + + float Gw[21]; //9+12. In registers + float *Gw_buffer = Gw+9; + + pointFunction_t func1[4] = {f_row1, f_row2, f_row3, f_row4}; + pointFunction_t func2[4] = {f_col1, f_col2, f_col3, f_col4}; + + for(int bk=0; bk= BC) return; + + // each thread in row 0 puts its first element of 1st filter tile(loaded by the thread) in smem + // taking 32 slots + // then puts its first element of 2nd filter tile immediately after, taking another 32 slots + // then followed by threads in row 1, 2.. until 7 + + // Note the next element is BK*BC (8*64) slots away, then another BK*BC .... 
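+    // i.e. element (i,j) of a thread's 4x4 filter tile is written at c_tensor_s + (i*4 + j)*c_offset_s,
+    // so consecutive transformed values of one filter sit BK*BC = 512 float slots apart in smem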
+ // for every 64 values, the first 32 belongs to filter tile 1, the next 32 for filter tile 2 + + for(int k=0; k<2; k++){ // prefetch 2 filter tiles/thread + for(int i=0; i<4; i++){ + #pragma unroll + for(int j=0; j<4; j++){ + pOutputs[c_tensor_s + i*c_offset_s*4 + j*c_offset_s] = tiles[k*16 + i*4 + j]; + } + } + // 2nd tile right behind the 1st? + c_tensor_s += BN; // BN has nothing to do with input tiles + } + +} + +__device__ __forceinline__ void prefetch_filter_tile(float *pInputs, float *tiles, int filt_k){ + + int c_tensor = blockIdx.z*BK + (threadIdx.y*filt_k<<4) + threadIdx.x; // Iny*filt_k*4*4 + // each threadIdx.y corresponds to one channel; there are 8 different threadIdx.y so 8 channels + + //each thread (32 threads in x direction) loads 2 kernel tiles (32 in K direction apart) + // save the two tiles in a float[32] register, float[16] for each + + int acumm; + #pragma unroll + for(int i=0; i<4; i++){ + acumm = (i*filt_k<<2); + #pragma unroll + for(int j=0; j<4; j++){ + tiles[(i<<2) + j] = pInputs[acumm + j*filt_k + c_tensor]; + tiles[16 + (i<<2) + j] = pInputs[acumm + j*filt_k + c_tensor+BN]; + } + } +} + +__device__ __forceinline__ void prefetch_input_tile(float *pInputs, float *tile, int in_h, + int in_w, int tw, int th, unsigned short mask){ + + // load one input tile + int tx = TW, ty = TH; + int c_tile = blockIdx.x * tx + blockIdx.y * in_w * ty; + int c_tensor = c_tile + (threadIdx.x % tw) * 2 + (threadIdx.x / tw) * in_w * 2 + + threadIdx.y*(in_h*in_w) - (in_w+1); + + int acumm,x; + + + if(mask==0xFFFF){ + #pragma unroll + for(int i=0; i<4; i++){ + acumm = i*in_w; + #pragma unroll + for(int j=0; j<4; j++){ + tile[(i<<2) + j] = pInputs[acumm + j + c_tensor]; + } + } + + } else { + for(int i=0; i<4; i++){ + acumm = i*in_w; + #pragma unroll + for(int j=0; j<4; j++){ + x = (i<<2) + j; + tile[x] = 0.f; + if(mask&(1< k1-1) m &= 0x0; //pad all zeros since this tile does not exist + }else if(tiles_dim_h % Y == 0){ + int k = in_w % TW; + int k1 = k % 2 ? (k+1)/2 : k/2; // there could be 8*k1 tiles + if(blockIdx.y==gridDim.y-1 && threadIdx.x / X == Y-1) m &= (!(in_h%2))?(0x0FFF):(0x00FF); //pad bottom row or bottom 2 rows + if(blockIdx.x==gridDim.x-1 && threadIdx.x % X == k1-1) m &= (!(k%2))?(0x7777):(0x3333); // pad right col or right 2 cols + if(blockIdx.x==gridDim.x-1 && threadIdx.x % X > k1-1) m &= 0x0; //pad all zeros since this tile does not exist + }else{ + int kh = in_h % TH; + int kw = in_w % TW; + int kh1 = kh % 2 ? (kh+1)/2 : kh/2; // there could be kh1*kw1 tiles + int kw1 = kw % 2 ? (kw+1)/2 : kw/2; + if(blockIdx.y==gridDim.y-1 && threadIdx.x / X == kh1-1) m &= (!(kh%2))?(0x0FFF):(0x00FF); //pad bottom row or bottom 2 rows + if(blockIdx.y==gridDim.y-1 && threadIdx.x / X > kh1-1) m &= 0x0; //pad all zeros since this tile does not exist + if(blockIdx.x==gridDim.x-1 && threadIdx.x % X == kw1-1) m &= (!(kw%2))?(0x7777):(0x3333); // pad right col or right 2 cols + if(blockIdx.x==gridDim.x-1 && threadIdx.x % X > kw1-1) m &= 0x0; //pad all zeros since this tile does not exist + } + if(blockIdx.x==0 && (threadIdx.x % X) == 0) m &=0xeeee; // pad left col + + float img_tile[16]; // Prefetch input from GMEM + float filter_tile[32]; // Prefetch filter from GMEM + + float4 input_frag_mem[8]; //2*2(2*8/4) Data to do Outer Product + prefetch f. SMEM (double_buffer) + float4 filter_frag_mem[8]; //2*2 Data to do Outer Product + prefetch f. 
SMEM (double_buffer) + float4 accumulator[2][16] = {0.0f}; // Accumulators + + float4 *A_frag; // Input data pointer + int frag_offset = 2 * (BN*BC); // (2=8/4) SMEM input read offset + + float4 *B_frag; // Filter data pointer + int f_frag_offset = 2 * (BC*BK); // (2=8/4 with 4 being float4) SMEM filter read offset + + + float4 *input_frag = (float4*) input_frag_mem; + float4 *filter_frag = (float4*) filter_frag_mem; + + float4 *swap_filter; + float4 *swap_input; + + prefetch_input_tile(A, img_tile, in_h, in_w, X, Y, m); + prefetch_filter_tile(B, filter_tile, filt_k); + + float4 *input_frag_buffer = (float4*) (input_frag+4); + float4 *filter_frag_buffer = (float4*) (filter_frag+4); + + // Mainloop - iterates over the entire K dimension - not unrolled + for(int iter=0; iter>>(w, Ww, filt_k, filt_c, filt_h, filt_w, alpha); + + // each thread block will load 32 tiles (4x4) from the single image input + // we let X*Y = 32 and arbitraraly pick X = 4 and Y = 8 + Winograd_kernel<<>>(k, Ww, C, + tiles_dim_w, tiles_dim_h, in_c, in_h, in_w, tile_size, X, Y, filt_k, filt_c, out_c, tile_2d_s, out_h, out_w); + + return cudaGetLastError(); +} + +} + + +void ggml_cuda_op_winograd_stage0(ggml_backend_cuda_context & ctx, ggml_tensor * dst); +void ggml_cuda_op_winograd_stage1(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + From 02a3cb1bf14129825d94a97124b8a2a1c8cfb5c4 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Fri, 27 Sep 2024 11:44:32 -0400 Subject: [PATCH 04/23] winograd build ok --- include/ggml.h | 15 + src/ggml-cuda.cu | 9 + src/ggml-cuda/conv-winograd.cu | 834 +++++++++++++++++++++++++++++--- src/ggml-cuda/conv-winograd.cuh | 723 --------------------------- src/ggml.c | 37 +- tests/test-conv2d-winograd.cpp | 395 +++++++++++++++ 6 files changed, 1223 insertions(+), 790 deletions(-) create mode 100644 tests/test-conv2d-winograd.cpp diff --git a/include/ggml.h b/include/ggml.h index ce8368947..3bba48a79 100644 --- a/include/ggml.h +++ b/include/ggml.h @@ -1698,6 +1698,21 @@ extern "C" { struct ggml_tensor * a, struct ggml_tensor * b, int stride); + + GGML_API struct ggml_tensor * ggml_winograd_stage0( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_winograd_stage1( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + GGML_API struct ggml_tensor * ggml_conv_2d_3x3( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + enum ggml_op_pool { GGML_OP_POOL_MAX, diff --git a/src/ggml-cuda.cu b/src/ggml-cuda.cu index 0bb7f2d99..c70f5b70c 100644 --- a/src/ggml-cuda.cu +++ b/src/ggml-cuda.cu @@ -10,6 +10,7 @@ #include "ggml-cuda/clamp.cuh" #include "ggml-cuda/concat.cuh" #include "ggml-cuda/conv-transpose-1d.cuh" +#include "ggml-cuda/conv-winograd.cuh" #include "ggml-cuda/convert.cuh" #include "ggml-cuda/cpy.cuh" #include "ggml-cuda/cross-entropy-loss.cuh" @@ -2331,6 +2332,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_CONV_TRANSPOSE_1D: ggml_cuda_op_conv_transpose_1d(ctx,dst); break; + case GGML_OP_WINOGRAD_STAGE0: + ggml_cuda_op_winograd_stage0(ctx, dst); + break; + case GGML_OP_WINOGRAD_STAGE1: + ggml_cuda_op_winograd_stage1(ctx, dst); + break; case GGML_OP_POOL_2D: ggml_cuda_op_pool2d(ctx, dst); break; @@ -2950,6 +2957,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons } return false; } break; + case GGML_OP_WINOGRAD_STAGE0: + case GGML_OP_WINOGRAD_STAGE1: case GGML_OP_NONE: case GGML_OP_RESHAPE: case 
GGML_OP_VIEW: diff --git a/src/ggml-cuda/conv-winograd.cu b/src/ggml-cuda/conv-winograd.cu index 2ff57f822..eaa2b9c42 100644 --- a/src/ggml-cuda/conv-winograd.cu +++ b/src/ggml-cuda/conv-winograd.cu @@ -1,93 +1,791 @@ -#include "conv-transpose-1d.cuh" +#include "conv-winograd.cuh" +#include "convert.cuh" + +__device__ void __inline__ outer_product(float4* input_frag, float4* filter_frag, float4 accumulator[][16]){ + accumulator[0][0].x += input_frag[0].x*filter_frag[0].x; + accumulator[0][0].y += input_frag[0].y*filter_frag[0].x; + accumulator[0][0].z += input_frag[0].z*filter_frag[0].x; + accumulator[0][0].w += input_frag[0].w*filter_frag[0].x; + + accumulator[0][1].x += input_frag[1].x*filter_frag[0].x; + accumulator[0][1].y += input_frag[1].y*filter_frag[0].x; + accumulator[0][1].z += input_frag[1].z*filter_frag[0].x; + accumulator[0][1].w += input_frag[1].w*filter_frag[0].x; + + accumulator[0][2].x += input_frag[0].x*filter_frag[0].y; + accumulator[0][2].y += input_frag[0].y*filter_frag[0].y; + accumulator[0][2].z += input_frag[0].z*filter_frag[0].y; + accumulator[0][2].w += input_frag[0].w*filter_frag[0].y; + + accumulator[0][3].x += input_frag[1].x*filter_frag[0].y; + accumulator[0][3].y += input_frag[1].y*filter_frag[0].y; + accumulator[0][3].z += input_frag[1].z*filter_frag[0].y; + accumulator[0][3].w += input_frag[1].w*filter_frag[0].y; + + accumulator[0][4].x += input_frag[0].x*filter_frag[0].z; + accumulator[0][4].y += input_frag[0].y*filter_frag[0].z; + accumulator[0][4].z += input_frag[0].z*filter_frag[0].z; + accumulator[0][4].w += input_frag[0].w*filter_frag[0].z; + + accumulator[0][5].x += input_frag[1].x*filter_frag[0].z; + accumulator[0][5].y += input_frag[1].y*filter_frag[0].z; + accumulator[0][5].z += input_frag[1].z*filter_frag[0].z; + accumulator[0][5].w += input_frag[1].w*filter_frag[0].z; + + accumulator[0][6].x += input_frag[0].x*filter_frag[0].w; + accumulator[0][6].y += input_frag[0].y*filter_frag[0].w; + accumulator[0][6].z += input_frag[0].z*filter_frag[0].w; + accumulator[0][6].w += input_frag[0].w*filter_frag[0].w; + + accumulator[0][7].x += input_frag[1].x*filter_frag[0].w; + accumulator[0][7].y += input_frag[1].y*filter_frag[0].w; + accumulator[0][7].z += input_frag[1].z*filter_frag[0].w; + accumulator[0][7].w += input_frag[1].w*filter_frag[0].w; + + // + accumulator[0][8].x += input_frag[0].x*filter_frag[1].x; + accumulator[0][8].y += input_frag[0].y*filter_frag[1].x; + accumulator[0][8].z += input_frag[0].z*filter_frag[1].x; + accumulator[0][8].w += input_frag[0].w*filter_frag[1].x; + + accumulator[0][9].x += input_frag[1].x*filter_frag[1].x; + accumulator[0][9].y += input_frag[1].y*filter_frag[1].x; + accumulator[0][9].z += input_frag[1].z*filter_frag[1].x; + accumulator[0][9].w += input_frag[1].w*filter_frag[1].x; + + accumulator[0][10].x += input_frag[0].x*filter_frag[1].y; + accumulator[0][10].y += input_frag[0].y*filter_frag[1].y; + accumulator[0][10].z += input_frag[0].z*filter_frag[1].y; + accumulator[0][10].w += input_frag[0].w*filter_frag[1].y; + + accumulator[0][11].x += input_frag[1].x*filter_frag[1].y; + accumulator[0][11].y += input_frag[1].y*filter_frag[1].y; + accumulator[0][11].z += input_frag[1].z*filter_frag[1].y; + accumulator[0][11].w += input_frag[1].w*filter_frag[1].y; + + accumulator[0][12].x += input_frag[0].x*filter_frag[1].z; + accumulator[0][12].y += input_frag[0].y*filter_frag[1].z; + accumulator[0][12].z += input_frag[0].z*filter_frag[1].z; + accumulator[0][12].w += input_frag[0].w*filter_frag[1].z; + + accumulator[0][13].x += 
input_frag[1].x*filter_frag[1].z; + accumulator[0][13].y += input_frag[1].y*filter_frag[1].z; + accumulator[0][13].z += input_frag[1].z*filter_frag[1].z; + accumulator[0][13].w += input_frag[1].w*filter_frag[1].z; + + accumulator[0][14].x += input_frag[0].x*filter_frag[1].w; + accumulator[0][14].y += input_frag[0].y*filter_frag[1].w; + accumulator[0][14].z += input_frag[0].z*filter_frag[1].w; + accumulator[0][14].w += input_frag[0].w*filter_frag[1].w; + + accumulator[0][15].x += input_frag[1].x*filter_frag[1].w; + accumulator[0][15].y += input_frag[1].y*filter_frag[1].w; + accumulator[0][15].z += input_frag[1].z*filter_frag[1].w; + accumulator[0][15].w += input_frag[1].w*filter_frag[1].w; + + ////// + accumulator[1][0].x += input_frag[2].x*filter_frag[2].x; + accumulator[1][0].y += input_frag[2].y*filter_frag[2].x; + accumulator[1][0].z += input_frag[2].z*filter_frag[2].x; + accumulator[1][0].w += input_frag[2].w*filter_frag[2].x; + + accumulator[1][1].x += input_frag[3].x*filter_frag[2].x; + accumulator[1][1].y += input_frag[3].y*filter_frag[2].x; + accumulator[1][1].z += input_frag[3].z*filter_frag[2].x; + accumulator[1][1].w += input_frag[3].w*filter_frag[2].x; + + accumulator[1][2].x += input_frag[2].x*filter_frag[2].y; + accumulator[1][2].y += input_frag[2].y*filter_frag[2].y; + accumulator[1][2].z += input_frag[2].z*filter_frag[2].y; + accumulator[1][2].w += input_frag[2].w*filter_frag[2].y; + + accumulator[1][3].x += input_frag[3].x*filter_frag[2].y; + accumulator[1][3].y += input_frag[3].y*filter_frag[2].y; + accumulator[1][3].z += input_frag[3].z*filter_frag[2].y; + accumulator[1][3].w += input_frag[3].w*filter_frag[2].y; + + accumulator[1][4].x += input_frag[2].x*filter_frag[2].z; + accumulator[1][4].y += input_frag[2].y*filter_frag[2].z; + accumulator[1][4].z += input_frag[2].z*filter_frag[2].z; + accumulator[1][4].w += input_frag[2].w*filter_frag[2].z; + + accumulator[1][5].x += input_frag[3].x*filter_frag[2].z; + accumulator[1][5].y += input_frag[3].y*filter_frag[2].z; + accumulator[1][5].z += input_frag[3].z*filter_frag[2].z; + accumulator[1][5].w += input_frag[3].w*filter_frag[2].z; + + accumulator[1][6].x += input_frag[2].x*filter_frag[2].w; + accumulator[1][6].y += input_frag[2].y*filter_frag[2].w; + accumulator[1][6].z += input_frag[2].z*filter_frag[2].w; + accumulator[1][6].w += input_frag[2].w*filter_frag[2].w; + + accumulator[1][7].x += input_frag[3].x*filter_frag[2].w; + accumulator[1][7].y += input_frag[3].y*filter_frag[2].w; + accumulator[1][7].z += input_frag[3].z*filter_frag[2].w; + accumulator[1][7].w += input_frag[3].w*filter_frag[2].w; + + // + accumulator[1][8].x += input_frag[2].x*filter_frag[3].x; + accumulator[1][8].y += input_frag[2].y*filter_frag[3].x; + accumulator[1][8].z += input_frag[2].z*filter_frag[3].x; + accumulator[1][8].w += input_frag[2].w*filter_frag[3].x; + + accumulator[1][9].x += input_frag[3].x*filter_frag[3].x; + accumulator[1][9].y += input_frag[3].y*filter_frag[3].x; + accumulator[1][9].z += input_frag[3].z*filter_frag[3].x; + accumulator[1][9].w += input_frag[3].w*filter_frag[3].x; + + accumulator[1][10].x += input_frag[2].x*filter_frag[3].y; + accumulator[1][10].y += input_frag[2].y*filter_frag[3].y; + accumulator[1][10].z += input_frag[2].z*filter_frag[3].y; + accumulator[1][10].w += input_frag[2].w*filter_frag[3].y; + + accumulator[1][11].x += input_frag[3].x*filter_frag[3].y; + accumulator[1][11].y += input_frag[3].y*filter_frag[3].y; + accumulator[1][11].z += input_frag[3].z*filter_frag[3].y; + accumulator[1][11].w += 
input_frag[3].w*filter_frag[3].y; + + accumulator[1][12].x += input_frag[2].x*filter_frag[3].z; + accumulator[1][12].y += input_frag[2].y*filter_frag[3].z; + accumulator[1][12].z += input_frag[2].z*filter_frag[3].z; + accumulator[1][12].w += input_frag[2].w*filter_frag[3].z; + + accumulator[1][13].x += input_frag[3].x*filter_frag[3].z; + accumulator[1][13].y += input_frag[3].y*filter_frag[3].z; + accumulator[1][13].z += input_frag[3].z*filter_frag[3].z; + accumulator[1][13].w += input_frag[3].w*filter_frag[3].z; + + accumulator[1][14].x += input_frag[2].x*filter_frag[3].w; + accumulator[1][14].y += input_frag[2].y*filter_frag[3].w; + accumulator[1][14].z += input_frag[2].z*filter_frag[3].w; + accumulator[1][14].w += input_frag[2].w*filter_frag[3].w; + + accumulator[1][15].x += input_frag[3].x*filter_frag[3].w; + accumulator[1][15].y += input_frag[3].y*filter_frag[3].w; + accumulator[1][15].z += input_frag[3].z*filter_frag[3].w; + accumulator[1][15].w += input_frag[3].w*filter_frag[3].w; + } + +extern "C" +{ + +__device__ __forceinline__ void transform_output_tile(float *pOutputs, float2 *C_tile, float2 *At, + int round, int c_tensor, int c_glb_offset, int i1, int i2, + unsigned short mask1, unsigned short mask2, int out_w) +{ + + c_tensor += (((round)/2)*32 + ((round)%2)*2)*c_glb_offset; + int x, x1; + + #pragma unroll + for(int j=0; j<4; j++){ + + At[j].x = C_tile[j].x + C_tile[4+j].x + C_tile[8+j].x; + At[j].y = C_tile[j].y + C_tile[4+j].y + C_tile[8+j].y; + + At[4+j].x = C_tile[4+j].x - C_tile[8+j].x - C_tile[12+j].x; + At[4+j].y = C_tile[4+j].y - C_tile[8+j].y - C_tile[12+j].y; + + } + + #pragma unroll + for(int i=0; i<2; i++){ + x = i*4; + x1 = i*((out_w-(out_w%2)) + (out_w%2)/2); + + if(mask1&(1<<(i*2))){ + pOutputs[x1 + c_tensor + i1] = At[x].x + At[x+1].x + At[x+2].x; + } + if(mask2&(1<<(i*2))){ + pOutputs[x1 + c_tensor + i2] = At[x].y + At[x+1].y + At[x+2].y; + } + if(mask1&(1<<(i*2+1))){ + pOutputs[x1 + c_tensor + i1 + 1] = At[x+1].x - At[x+2].x - At[x+3].x; + } + if(mask2&(1<<(i*2+1))){ + pOutputs[x1 + c_tensor + i2 + 1] = At[x+1].y - At[x+2].y - At[x+3].y; + } + } +} -static __global__ void conv_transpose_1d_kernel( - const int s0, const int p0, const int d0, const int output_size, - const int src0_ne0, const int src0_ne1, const int src0_ne2, const int src0_ne3, - const int src1_ne0, const int src1_ne1, const int src1_ne2, const int src1_ne3, - const int dst_ne0, const int dst_ne1, const int dst_ne2, const int dst_ne3, - const float * src0, const float * src1, float * dst) { - int global_index = threadIdx.x + blockIdx.x * blockDim.x; - if (global_index >= output_size) { - return; +__device__ __forceinline__ unsigned short get_mask(int idd, int tiles_dim_w, int tiles_dim_h, + int tw, int th, int out_w, int out_h){ + + unsigned short mask = 0x000F; + // if((blockIdx.y/tiles_dim)==(tiles_dim-1) && out_w%2) mask&=0x0003; // pad bottom row + // if(!((blockIdx.y+1)%tiles_dim) && out_w%2) mask&=0X0005; // pad right col + // if(blockIdx.y==gridDim.y-1 && (idd / tw) == th-1 && out_h%2) mask&=0x0003; // pad bottom row + // if(blockIdx.x==gridDim.x-1 && (idd % tw) == tw-1 && out_w%2) mask&=0X0005; // pad right col + if(tiles_dim_w % tw == 0 && tiles_dim_h % th == 0){ + if(blockIdx.y==gridDim.y-1 && (idd / tw) == th-1 && out_h%2) mask&=0x0003; // pad bottom row + if(blockIdx.x==gridDim.x-1 && (idd % tw) == tw-1 && out_w%2) mask&=0X0005; // pad right col + }else if(tiles_dim_w % tw == 0){ + int k = out_h % TH; + int k1 = k % 2 ? 
(k+1)/2 : k/2; // there could be 4*k1 tiles + if(blockIdx.y==gridDim.y-1 && (idd / tw) == k1-1 && k%2) mask&=0x0003; // pad bottom row + if(blockIdx.y==gridDim.y-1 && (idd / tw) > k1-1) mask &= 0x0; //pad all zeros since this tile does not exist + }else if(tiles_dim_h % th == 0){ + int k = out_w % TW; + int k1 = k % 2 ? (k+1)/2 : k/2; // there could be 4*k1 tiles + if(blockIdx.x==gridDim.x-1 && (idd % tw) == k1-1 && k%2) mask&=0X0005; // pad right col + if(blockIdx.x==gridDim.x-1 && (idd % tw) > k1-1) mask&=0X0; // pad all zeroes + }else{ + int kh = out_h % TH; + int kw = out_w % TW; + int kh1 = kh % 2 ? (kh+1)/2 : kh/2; // there could be kh1*kw1 tiles + int kw1 = kw % 2 ? (kw+1)/2 : kw/2; + if(blockIdx.y==gridDim.y-1 && (idd / tw) == kh1-1 && kh%2) mask&=0x0003; // pad bottom row + if(blockIdx.x==gridDim.x-1 && (idd % tw) == kw1-1 && kw%2) mask&=0X0005; // pad right col + if(blockIdx.y==gridDim.y-1 && (idd / tw) > kh1-1) mask &= 0x0; //pad all zeros since this tile does not exist + if(blockIdx.x==gridDim.x-1 && (idd % tw) > kw1-1) mask &= 0X0; // pad all zeroes + } + return mask; +} + +__device__ __forceinline__ void store_output_tile(float4 acumm_smem[][16], float *shared_mem, float *C, +int out_h, int out_w, int tiles_dim_w, int tiles_dim_h, int tw, int th, +float4 *input_frag_mem, float4* filter_frag_mem){ + + float2 *output_smem = (float2 *) shared_mem; + float2 *accumulator = (float2 *) acumm_smem; + float2 *C_out = (float2*)C; + + float2 *C_tile = (float2*) input_frag_mem; + float2 *At = (float2*) filter_frag_mem; + // for output transformation, the role of threadIdx.y changes again: + // in the main loop, different threadIdx.y deal with different element of the 4x4 tile + // here, they are for 4 different groups of lane ids from optSTS64 layout + // for init (and init+32), we need to identify its tile number (0-31) within the supertile + // first, from init, find out from which threadIdx.x it comes. + + // now we got l, which is the land id which computed accumulated sum for the tile element + // each lane id (or threadIdx.x) computed 8 tiles which are distributed into 4 locations spreading + // over the smem. We need to find which of the 8 the current tile is. 
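+    // (tileid is the __constant__ table filled from tid[2][32] via cudaMemcpyToSymbol before launch;
+    //  lane l owns tiles tileid[0][l] and tileid[1][l], and tile index idd sits at column idd % tw,
+    //  row idd / tw of the block's tw x th tile grid)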
+ // use tileid table to figure out + + // for 2nd tile + + int idd1 = tileid[0][threadIdx.x]; + int id1 = (idd1 % tw) * 2 + (idd1 / tw) * out_w * 2; + int idd2 = tileid[1][threadIdx.x]; + int id2 = (idd2 % tw) * 2 + (idd2 / tw) * out_w * 2; + + // unsigned short mask1 = 0x000F; + unsigned short mask1 = get_mask(idd1, tiles_dim_w, tiles_dim_h, tw, th, out_w, out_h); + unsigned short mask2 = get_mask(idd2, tiles_dim_w, tiles_dim_h, tw, th, out_w, out_h); + + // output transpose step + int t=0; + int acumm1, acumm2; + // For transposing + //acumm1 = access_s_out[Inx]; //* 4 + acumm1 = ((threadIdx.x%8)/2)*34 + threadIdx.x%2 + (threadIdx.x/16)*2 + ((threadIdx.x/8)%2)*8; + acumm2 = acumm1+4; + + int acumm4 = BN_p*16 ; //*4 + int idx = threadIdx.y * BN_p; + int idx2 = idx + BN_p*8; //(BN_p*2 *8)/2 + + // For transformating + int offset = BN_p *2; //*2/2 + int init = ( (threadIdx.y/4)*BN_p*16 + (threadIdx.y%4)*(32+2) ) *2 + threadIdx.x; + + int c_glb_offset = out_h*out_w; + // int c_tensor = blockIdx.z*c_glb_offset*BK + (blockIdx.y%tiles_dim)*2 + (blockIdx.y/tiles_dim)*out_w*2 + + // blockIdx.x*BN + (threadIdx.x%16)*2+ + // ((threadIdx.x/16)*16 + (threadIdx.y%4)*4 + threadIdx.y/4)*c_glb_offset; + + int tx = TW, ty = TH; + // int c_tile = blockIdx.x * tx + blockIdx.y * in_w * ty; + // int c_tensor = c_tile + (threadIdx.x % tw) * 2 + (threadIdx.x / tw) * in_w * 2 + + // threadIdx.y*(in_h*in_w) - (in_w+1); + + int c_tensor = blockIdx.z*c_glb_offset*BK + blockIdx.x * tx + blockIdx.y * out_w * ty + + // (threadIdx.x % tw) * 2 + (threadIdx.x / tw) * out_w * 2 + + ((threadIdx.x/16)*16 + (threadIdx.y%4)*4 + threadIdx.y/4)*c_glb_offset; + + #pragma unroll + for(int round=0; round<4; round++){ + + *( (float2*) (output_smem + idx + acumm1) ) = *(accumulator+t); + *( (float2*) (output_smem + idx + acumm1 + 16) ) = *(accumulator+t+1); // float 4, t + *( (float2*) (output_smem + idx + acumm2) ) = *(accumulator+t+2); + *( (float2*) (output_smem + idx + acumm2 + 16) ) = *(accumulator+t+3); // float 4, t+1 + + + *( (float2*) (output_smem + idx2 + acumm1) ) = *(accumulator+t+32); + *( (float2*) (output_smem + idx2 + acumm1 + 16) ) = *(accumulator+t+33); // float 4, t+16 + *( (float2*) (output_smem + idx2 + acumm2) ) = *(accumulator+t+34); + *( (float2*) (output_smem + idx2 + acumm2 + 16) ) = *(accumulator+t+35); // float 4, t+17 + + // the above 8 float2 will be consumed by theadIdx.y = [0,1,2,3] + + // the following 8 float2 will be consumed by theadIdx.y = [4,5,6,7] + + *( (float2*) (output_smem + idx + acumm4 + acumm1) ) = *(accumulator+t+4); + *( (float2*) (output_smem + idx + acumm4 + acumm1 + 16) ) = *(accumulator+t+5); // float 4, t+2 + *( (float2*) (output_smem + idx + acumm4 + acumm2) ) = *(accumulator+t+6); + *( (float2*) (output_smem + idx + acumm4 + acumm2 + 16) ) = *(accumulator+t+7); // float 4, t+3 + + *( (float2*) (output_smem + idx2 + acumm4 + acumm1) ) = *(accumulator+t+36); + *( (float2*) (output_smem + idx2 + acumm4 + acumm1 + 16) ) = *(accumulator+t+37); // float 4, t+18 + *( (float2*) (output_smem + idx2 + acumm4 + acumm2) ) = *(accumulator+t+38); + *( (float2*) (output_smem + idx2 + acumm4 + acumm2 + 16) ) = *(accumulator+t+39); // float 4, t+19 + + + + t+=8; + + __syncthreads(); + + + + + for(int i=0; i<16; i++){ + C_tile[i].x = shared_mem[i*offset + init]; + C_tile[i].y = shared_mem[i*offset + init + 32]; + } + - int out_index = global_index / dst_ne0; + // transform output tiles + transform_output_tile(C, C_tile, At, round, c_tensor, c_glb_offset, id1, id2, mask1, mask2, out_w); + 
__syncthreads(); + } +} - float accumulator = 0; - for (int c = 0; c < src0_ne2; c++) { - int idx = global_index % dst_ne0; +// Set of functions per row in Gw product +__device__ float f_row1(float *Gw, int j){ + return Gw[j]; + } + __device__ float f_row2(float *Gw, int j){ + return 0.5*(Gw[j] + Gw[6+j] + Gw[3+j]); + } + __device__ float f_row3(float *Gw, int j){ + return 0.5*(Gw[j] + Gw[6+j] - Gw[3+j]); + } + __device__ float f_row4(float *Gw, int j){ + return Gw[6+j]; + } + // Set of functions per column in GwGt product + __device__ float f_col1(float *Gw, int j){ + return Gw[j]; + } + __device__ float f_col2(float *Gw, int j){ + return 0.5*(Gw[j] + Gw[j+2] + Gw[j+1]); + } + __device__ float f_col3(float *Gw, int j){ + return 0.5*(Gw[j] + Gw[j+2] - Gw[j+1]); + } + __device__ float f_col4(float *Gw, int j){ + return Gw[j+2]; + } + + typedef float(*pointFunction_t)(float *, int); + + __global__ void FX(const float *pInputs, float *pOutputs, int filt_k, + int filt_c, int filt_h, int filt_w){ + int Inx = threadIdx.x, Iny = threadIdx.y; + int TileX = blockIdx.x, TileY = blockIdx.y; + + int c_glb_offset = filt_k*filt_h*filt_w; + int c_kernel = TileY*BC*c_glb_offset + TileX*BK + Iny*c_glb_offset + Inx; + int c_glb_offset_s = filt_k*4*4; + int c_kernel_s = TileY*BC*c_glb_offset_s + TileX*BK + Iny*c_glb_offset_s + Inx; + + float Gw[21]; //9+12. In registers + float *Gw_buffer = Gw+9; + + pointFunction_t func1[4] = {f_row1, f_row2, f_row3, f_row4}; + pointFunction_t func2[4] = {f_col1, f_col2, f_col3, f_col4}; + + for(int bk=0; bk= i*s0 && idx < i*s0 + src0_ne0)) { - continue; - } - int weight_idx = idx - i*s0; +__device__ __forceinline__ void load_filter_tile(float *tiles, float *pOutputs, + int filt_c, int filt_k){ + + int c_tensor_s = threadIdx.y*BK + threadIdx.x; + int c_offset_s = BK*BC; + // if(threadIdx.y >= BC) return; + + // each thread in row 0 puts its first element of 1st filter tile(loaded by the thread) in smem + // taking 32 slots + // then puts its first element of 2nd filter tile immediately after, taking another 32 slots + // then followed by threads in row 1, 2.. until 7 + + // Note the next element is BK*BC (8*64) slots away, then another BK*BC .... + // for every 64 values, the first 32 belongs to filter tile 1, the next 32 for filter tile 2 + + for(int k=0; k<2; k++){ // prefetch 2 filter tiles/thread + for(int i=0; i<4; i++){ + #pragma unroll + for(int j=0; j<4; j++){ + pOutputs[c_tensor_s + i*c_offset_s*4 + j*c_offset_s] = tiles[k*16 + i*4 + j]; + } + } + // 2nd tile right behind the 1st? 
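+      // advancing by BN (= 32) slots places the 2nd filter tile's values in the second half of each 64-slot group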
+ c_tensor_s += BN; // BN has nothing to do with input tiles + } + +} - float kernel_weight = src0[kernel_offset + weight_idx]; - float input_value = src1[input_offset+i]; +__device__ __forceinline__ void prefetch_filter_tile(const float *pInputs, float *tiles, int filt_k){ + + int c_tensor = blockIdx.z*BK + (threadIdx.y*filt_k<<4) + threadIdx.x; // Iny*filt_k*4*4 + // each threadIdx.y corresponds to one channel; there are 8 different threadIdx.y so 8 channels + + //each thread (32 threads in x direction) loads 2 kernel tiles (32 in K direction apart) + // save the two tiles in a float[32] register, float[16] for each + + int acumm; + #pragma unroll + for(int i=0; i<4; i++){ + acumm = (i*filt_k<<2); + #pragma unroll + for(int j=0; j<4; j++){ + tiles[(i<<2) + j] = pInputs[acumm + j*filt_k + c_tensor]; + tiles[16 + (i<<2) + j] = pInputs[acumm + j*filt_k + c_tensor+BN]; + } + } +} - accumulator += kernel_weight * input_value; - } +__device__ __forceinline__ void prefetch_input_tile(const float *pInputs, float *tile, int in_h, + int in_w, int tw, int th, unsigned short mask){ + + // load one input tile + int tx = TW, ty = TH; + int c_tile = blockIdx.x * tx + blockIdx.y * in_w * ty; + int c_tensor = c_tile + (threadIdx.x % tw) * 2 + (threadIdx.x / tw) * in_w * 2 + + threadIdx.y*(in_h*in_w) - (in_w+1); + + int acumm,x; + + + if(mask==0xFFFF){ + #pragma unroll + for(int i=0; i<4; i++){ + acumm = i*in_w; + #pragma unroll + for(int j=0; j<4; j++){ + tile[(i<<2) + j] = pInputs[acumm + j + c_tensor]; + } + } + + } else { + for(int i=0; i<4; i++){ + acumm = i*in_w; + #pragma unroll + for(int j=0; j<4; j++){ + x = (i<<2) + j; + tile[x] = 0.f; + if(mask&(1< k1-1) m &= 0x0; //pad all zeros since this tile does not exist + }else if(tiles_dim_h % Y == 0){ + int k = in_w % TW; + int k1 = k % 2 ? (k+1)/2 : k/2; // there could be 8*k1 tiles + if(blockIdx.y==gridDim.y-1 && threadIdx.x / X == Y-1) m &= (!(in_h%2))?(0x0FFF):(0x00FF); //pad bottom row or bottom 2 rows + if(blockIdx.x==gridDim.x-1 && threadIdx.x % X == k1-1) m &= (!(k%2))?(0x7777):(0x3333); // pad right col or right 2 cols + if(blockIdx.x==gridDim.x-1 && threadIdx.x % X > k1-1) m &= 0x0; //pad all zeros since this tile does not exist + }else{ + int kh = in_h % TH; + int kw = in_w % TW; + int kh1 = kh % 2 ? (kh+1)/2 : kh/2; // there could be kh1*kw1 tiles + int kw1 = kw % 2 ? (kw+1)/2 : kw/2; + if(blockIdx.y==gridDim.y-1 && threadIdx.x / X == kh1-1) m &= (!(kh%2))?(0x0FFF):(0x00FF); //pad bottom row or bottom 2 rows + if(blockIdx.y==gridDim.y-1 && threadIdx.x / X > kh1-1) m &= 0x0; //pad all zeros since this tile does not exist + if(blockIdx.x==gridDim.x-1 && threadIdx.x % X == kw1-1) m &= (!(kw%2))?(0x7777):(0x3333); // pad right col or right 2 cols + if(blockIdx.x==gridDim.x-1 && threadIdx.x % X > kw1-1) m &= 0x0; //pad all zeros since this tile does not exist + } + if(blockIdx.x==0 && (threadIdx.x % X) == 0) m &=0xeeee; // pad left col + + float img_tile[16]; // Prefetch input from GMEM + float filter_tile[32]; // Prefetch filter from GMEM + + float4 input_frag_mem[8]; //2*2(2*8/4) Data to do Outer Product + prefetch f. SMEM (double_buffer) + float4 filter_frag_mem[8]; //2*2 Data to do Outer Product + prefetch f. 
SMEM (double_buffer) + float4 accumulator[2][16] = {0.0f}; // Accumulators + + float4 *A_frag; // Input data pointer + int frag_offset = 2 * (BN*BC); // (2=8/4) SMEM input read offset + + float4 *B_frag; // Filter data pointer + int f_frag_offset = 2 * (BC*BK); // (2=8/4 with 4 being float4) SMEM filter read offset + + + float4 *input_frag = (float4*) input_frag_mem; + float4 *filter_frag = (float4*) filter_frag_mem; + + float4 *swap_filter; + float4 *swap_input; + + prefetch_input_tile(A, img_tile, in_h, in_w, X, Y, m); + prefetch_filter_tile(B, filter_tile, filt_k); + + float4 *input_frag_buffer = (float4*) (input_frag+4); + float4 *filter_frag_buffer = (float4*) (filter_frag+4); + + // Mainloop - iterates over the entire K dimension - not unrolled + for(int iter=0; iter>>(w, Ww, filt_k, filt_c, filt_h, filt_w); + + // each thread block will load 32 tiles (4x4) from the single image input + // we let X*Y = 32 and arbitraraly pick X = 4 and Y = 8 + Winograd_kernel<<>>(k, Ww, C, + tiles_dim_w, tiles_dim_h, in_c, in_h, in_w, tile_size, X, Y, filt_k, filt_c, out_c, tile_2d_s, out_h, out_w); + + return cudaGetLastError(); +} + } -static void conv_transpose_1d_f32_f32_cuda( - const int s0, const int p0, const int d0, const int output_size, + +static void conv_winograd_stage0_f32_f32_cuda( + const int src0_ne0, const int src0_ne1, const int src0_ne2, const int src0_ne3, + const int dst_ne0, const int dst_ne1, const int dst_ne2, const int dst_ne3, + const float * src0, float * dst, + cudaStream_t stream) { + + + int64_t filt_k = src0_ne3; + int64_t filt_c = src0_ne2; + + FX<<>>(src0, dst, filt_k, filt_c, src0_ne1, src0_ne0); + +} + +static void conv_winograd_stage1_f16_f32_cuda(int tiles_dim_w, int tiles_dim_h, int X, int Y, + int tile_size, int tile_2d_s, const int src0_ne0, const int src0_ne1, const int src0_ne2, const int src0_ne3, const int src1_ne0, const int src1_ne1, const int src1_ne2, const int src1_ne3, const int dst_ne0, const int dst_ne1, const int dst_ne2, const int dst_ne3, const float * src0, const float * src1, float * dst, cudaStream_t stream) { - const int num_blocks = (output_size + CUDA_CONV_TRANPOSE_1D_BLOCK_SIZE - 1) / CUDA_CONV_TRANPOSE_1D_BLOCK_SIZE; - conv_transpose_1d_kernel<<>>( - s0,p0,d0,output_size, - src0_ne0, src0_ne1, src0_ne2, src0_ne3, - src1_ne0, src1_ne1, src1_ne2, src1_ne3, - dst_ne0, dst_ne1, dst_ne2, dst_ne3, - src0,src1, dst); + int64_t filt_k = src0_ne3; + int64_t in_c = src1_ne2; + int64_t in_h = src1_ne1; + int64_t in_w = src1_ne0; + int64_t filt_c = src1_ne0; + int64_t out_c = filt_k; + int64_t out_h = in_h; + int64_t out_w = in_w; + int smem_size = (16*BN*BC + 16*BC*BK)*4; + + Winograd_kernel<<>>(src1, src0, dst, + tiles_dim_w, tiles_dim_h, in_c, in_h, in_w, tile_size, X, Y, filt_k, filt_c, out_c, tile_2d_s, out_h, out_w); } void ggml_cuda_op_winograd_stage0(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; - const float * src0_d = (const float *)src0->data; - - const ggml_tensor * src1 = dst->src[1]; - const float * src1_d = (const float *)src1->data; - + // const half * src0_d = (const float *)src0->data; + float * dst_d = (float *)dst->data; cudaStream_t stream = ctx.stream(); + int id = ggml_cuda_get_device(); - GGML_ASSERT(src0->type == GGML_TYPE_F32); + // GGML_ASSERT(src0->type == GGML_TYPE_F16); GGML_ASSERT( dst->type == GGML_TYPE_F32); - GGML_ASSERT(ggml_is_contiguous(src0)); - GGML_ASSERT(ggml_is_contiguous(src1)); - - const int32_t * opts = (const int32_t *)dst->op_params; - - const int s0 = 
opts[0]; - const int p0 = 0;//opts[3]; - const int d0 = 1;//opts[4]; - - const int64_t kernel_size = ggml_nelements(src0); - const int64_t input_size = ggml_nelements(src1); - const int64_t output_size = ggml_nelements(dst); + ggml_cuda_pool_alloc src0_ddq_as_f32(ctx.pool(id)); + if (src0->type != GGML_TYPE_F32) { + const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type); + GGML_ASSERT(to_fp32_cuda != nullptr); + int64_t nle = ggml_nelements(src0); + src0_ddq_as_f32.alloc(nle); + const half * src0_dd = (const half *)src0->data; + to_fp32_cuda(src0_dd, src0_ddq_as_f32.get(), nle, stream); + } - conv_transpose_1d_f32_f32_cuda(s0, p0, d0, output_size, - src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], - src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], + // GGML_ASSERT(ggml_is_contiguous(src0)); + const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *)src0->data : src0_ddq_as_f32.get(); + + conv_winograd_stage0_f32_f32_cuda(src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], - src0_d, src1_d, dst_d, stream); + src0_ddf_i, dst_d, stream); } + void ggml_cuda_op_winograd_stage1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const float * src0_d = (const float *)src0->data; @@ -99,22 +797,28 @@ void ggml_cuda_op_winograd_stage1(ggml_backend_cuda_context & ctx, ggml_tensor * cudaStream_t stream = ctx.stream(); GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); GGML_ASSERT(ggml_is_contiguous(src0)); GGML_ASSERT(ggml_is_contiguous(src1)); - const int32_t * opts = (const int32_t *)dst->op_params; + const int m = 2; + const int r = 3; + const int tile_size = m+r-1; + int tiles_dim_w, tiles_dim_h; + + tiles_dim_w = ceil(ceil((double)(src1->ne[0]+2)/2)-1); + tiles_dim_h = ceil(ceil((double)(src1->ne[1]+2)/2)-1); - const int s0 = opts[0]; - const int p0 = 0;//opts[3]; - const int d0 = 1;//opts[4]; + int tile_2d_s = tile_size*tile_size; - const int64_t kernel_size = ggml_nelements(src0); - const int64_t input_size = ggml_nelements(src1); - const int64_t output_size = ggml_nelements(dst); + cudaMemcpyToSymbol(access_f_s, aux, 64*sizeof(int)); + cudaMemcpyToSymbol(access_s, aux2, 64*sizeof(int)); + cudaMemcpyToSymbol(tileid, tid, 64*sizeof(int)); - conv_transpose_1d_f32_f32_cuda(s0, p0, d0, output_size, + conv_winograd_stage1_f16_f32_cuda(tiles_dim_w, tiles_dim_h, 4, 8, + tile_size, tile_2d_s, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], diff --git a/src/ggml-cuda/conv-winograd.cuh b/src/ggml-cuda/conv-winograd.cuh index 8c3c388a5..2569cf36f 100644 --- a/src/ggml-cuda/conv-winograd.cuh +++ b/src/ggml-cuda/conv-winograd.cuh @@ -37,729 +37,6 @@ const int tid[2][32] = { 2,3,6,7,10,11,14,15,18,19,22,23,26,27,30,31} }; -__device__ void __inline__ outer_product(float4* input_frag, float4* filter_frag, float4 accumulator[][16]){ - accumulator[0][0].x += input_frag[0].x*filter_frag[0].x; - accumulator[0][0].y += input_frag[0].y*filter_frag[0].x; - accumulator[0][0].z += input_frag[0].z*filter_frag[0].x; - accumulator[0][0].w += input_frag[0].w*filter_frag[0].x; - - accumulator[0][1].x += input_frag[1].x*filter_frag[0].x; - accumulator[0][1].y += input_frag[1].y*filter_frag[0].x; - accumulator[0][1].z += input_frag[1].z*filter_frag[0].x; - accumulator[0][1].w += input_frag[1].w*filter_frag[0].x; - - 
accumulator[0][2].x += input_frag[0].x*filter_frag[0].y; - accumulator[0][2].y += input_frag[0].y*filter_frag[0].y; - accumulator[0][2].z += input_frag[0].z*filter_frag[0].y; - accumulator[0][2].w += input_frag[0].w*filter_frag[0].y; - - accumulator[0][3].x += input_frag[1].x*filter_frag[0].y; - accumulator[0][3].y += input_frag[1].y*filter_frag[0].y; - accumulator[0][3].z += input_frag[1].z*filter_frag[0].y; - accumulator[0][3].w += input_frag[1].w*filter_frag[0].y; - - accumulator[0][4].x += input_frag[0].x*filter_frag[0].z; - accumulator[0][4].y += input_frag[0].y*filter_frag[0].z; - accumulator[0][4].z += input_frag[0].z*filter_frag[0].z; - accumulator[0][4].w += input_frag[0].w*filter_frag[0].z; - - accumulator[0][5].x += input_frag[1].x*filter_frag[0].z; - accumulator[0][5].y += input_frag[1].y*filter_frag[0].z; - accumulator[0][5].z += input_frag[1].z*filter_frag[0].z; - accumulator[0][5].w += input_frag[1].w*filter_frag[0].z; - - accumulator[0][6].x += input_frag[0].x*filter_frag[0].w; - accumulator[0][6].y += input_frag[0].y*filter_frag[0].w; - accumulator[0][6].z += input_frag[0].z*filter_frag[0].w; - accumulator[0][6].w += input_frag[0].w*filter_frag[0].w; - - accumulator[0][7].x += input_frag[1].x*filter_frag[0].w; - accumulator[0][7].y += input_frag[1].y*filter_frag[0].w; - accumulator[0][7].z += input_frag[1].z*filter_frag[0].w; - accumulator[0][7].w += input_frag[1].w*filter_frag[0].w; - - // - accumulator[0][8].x += input_frag[0].x*filter_frag[1].x; - accumulator[0][8].y += input_frag[0].y*filter_frag[1].x; - accumulator[0][8].z += input_frag[0].z*filter_frag[1].x; - accumulator[0][8].w += input_frag[0].w*filter_frag[1].x; - - accumulator[0][9].x += input_frag[1].x*filter_frag[1].x; - accumulator[0][9].y += input_frag[1].y*filter_frag[1].x; - accumulator[0][9].z += input_frag[1].z*filter_frag[1].x; - accumulator[0][9].w += input_frag[1].w*filter_frag[1].x; - - accumulator[0][10].x += input_frag[0].x*filter_frag[1].y; - accumulator[0][10].y += input_frag[0].y*filter_frag[1].y; - accumulator[0][10].z += input_frag[0].z*filter_frag[1].y; - accumulator[0][10].w += input_frag[0].w*filter_frag[1].y; - - accumulator[0][11].x += input_frag[1].x*filter_frag[1].y; - accumulator[0][11].y += input_frag[1].y*filter_frag[1].y; - accumulator[0][11].z += input_frag[1].z*filter_frag[1].y; - accumulator[0][11].w += input_frag[1].w*filter_frag[1].y; - - accumulator[0][12].x += input_frag[0].x*filter_frag[1].z; - accumulator[0][12].y += input_frag[0].y*filter_frag[1].z; - accumulator[0][12].z += input_frag[0].z*filter_frag[1].z; - accumulator[0][12].w += input_frag[0].w*filter_frag[1].z; - - accumulator[0][13].x += input_frag[1].x*filter_frag[1].z; - accumulator[0][13].y += input_frag[1].y*filter_frag[1].z; - accumulator[0][13].z += input_frag[1].z*filter_frag[1].z; - accumulator[0][13].w += input_frag[1].w*filter_frag[1].z; - - accumulator[0][14].x += input_frag[0].x*filter_frag[1].w; - accumulator[0][14].y += input_frag[0].y*filter_frag[1].w; - accumulator[0][14].z += input_frag[0].z*filter_frag[1].w; - accumulator[0][14].w += input_frag[0].w*filter_frag[1].w; - - accumulator[0][15].x += input_frag[1].x*filter_frag[1].w; - accumulator[0][15].y += input_frag[1].y*filter_frag[1].w; - accumulator[0][15].z += input_frag[1].z*filter_frag[1].w; - accumulator[0][15].w += input_frag[1].w*filter_frag[1].w; - - ////// - accumulator[1][0].x += input_frag[2].x*filter_frag[2].x; - accumulator[1][0].y += input_frag[2].y*filter_frag[2].x; - accumulator[1][0].z += input_frag[2].z*filter_frag[2].x; - 
accumulator[1][0].w += input_frag[2].w*filter_frag[2].x; - - accumulator[1][1].x += input_frag[3].x*filter_frag[2].x; - accumulator[1][1].y += input_frag[3].y*filter_frag[2].x; - accumulator[1][1].z += input_frag[3].z*filter_frag[2].x; - accumulator[1][1].w += input_frag[3].w*filter_frag[2].x; - - accumulator[1][2].x += input_frag[2].x*filter_frag[2].y; - accumulator[1][2].y += input_frag[2].y*filter_frag[2].y; - accumulator[1][2].z += input_frag[2].z*filter_frag[2].y; - accumulator[1][2].w += input_frag[2].w*filter_frag[2].y; - - accumulator[1][3].x += input_frag[3].x*filter_frag[2].y; - accumulator[1][3].y += input_frag[3].y*filter_frag[2].y; - accumulator[1][3].z += input_frag[3].z*filter_frag[2].y; - accumulator[1][3].w += input_frag[3].w*filter_frag[2].y; - - accumulator[1][4].x += input_frag[2].x*filter_frag[2].z; - accumulator[1][4].y += input_frag[2].y*filter_frag[2].z; - accumulator[1][4].z += input_frag[2].z*filter_frag[2].z; - accumulator[1][4].w += input_frag[2].w*filter_frag[2].z; - - accumulator[1][5].x += input_frag[3].x*filter_frag[2].z; - accumulator[1][5].y += input_frag[3].y*filter_frag[2].z; - accumulator[1][5].z += input_frag[3].z*filter_frag[2].z; - accumulator[1][5].w += input_frag[3].w*filter_frag[2].z; - - accumulator[1][6].x += input_frag[2].x*filter_frag[2].w; - accumulator[1][6].y += input_frag[2].y*filter_frag[2].w; - accumulator[1][6].z += input_frag[2].z*filter_frag[2].w; - accumulator[1][6].w += input_frag[2].w*filter_frag[2].w; - - accumulator[1][7].x += input_frag[3].x*filter_frag[2].w; - accumulator[1][7].y += input_frag[3].y*filter_frag[2].w; - accumulator[1][7].z += input_frag[3].z*filter_frag[2].w; - accumulator[1][7].w += input_frag[3].w*filter_frag[2].w; - - // - accumulator[1][8].x += input_frag[2].x*filter_frag[3].x; - accumulator[1][8].y += input_frag[2].y*filter_frag[3].x; - accumulator[1][8].z += input_frag[2].z*filter_frag[3].x; - accumulator[1][8].w += input_frag[2].w*filter_frag[3].x; - - accumulator[1][9].x += input_frag[3].x*filter_frag[3].x; - accumulator[1][9].y += input_frag[3].y*filter_frag[3].x; - accumulator[1][9].z += input_frag[3].z*filter_frag[3].x; - accumulator[1][9].w += input_frag[3].w*filter_frag[3].x; - - accumulator[1][10].x += input_frag[2].x*filter_frag[3].y; - accumulator[1][10].y += input_frag[2].y*filter_frag[3].y; - accumulator[1][10].z += input_frag[2].z*filter_frag[3].y; - accumulator[1][10].w += input_frag[2].w*filter_frag[3].y; - - accumulator[1][11].x += input_frag[3].x*filter_frag[3].y; - accumulator[1][11].y += input_frag[3].y*filter_frag[3].y; - accumulator[1][11].z += input_frag[3].z*filter_frag[3].y; - accumulator[1][11].w += input_frag[3].w*filter_frag[3].y; - - accumulator[1][12].x += input_frag[2].x*filter_frag[3].z; - accumulator[1][12].y += input_frag[2].y*filter_frag[3].z; - accumulator[1][12].z += input_frag[2].z*filter_frag[3].z; - accumulator[1][12].w += input_frag[2].w*filter_frag[3].z; - - accumulator[1][13].x += input_frag[3].x*filter_frag[3].z; - accumulator[1][13].y += input_frag[3].y*filter_frag[3].z; - accumulator[1][13].z += input_frag[3].z*filter_frag[3].z; - accumulator[1][13].w += input_frag[3].w*filter_frag[3].z; - - accumulator[1][14].x += input_frag[2].x*filter_frag[3].w; - accumulator[1][14].y += input_frag[2].y*filter_frag[3].w; - accumulator[1][14].z += input_frag[2].z*filter_frag[3].w; - accumulator[1][14].w += input_frag[2].w*filter_frag[3].w; - - accumulator[1][15].x += input_frag[3].x*filter_frag[3].w; - accumulator[1][15].y += input_frag[3].y*filter_frag[3].w; - 
accumulator[1][15].z += input_frag[3].z*filter_frag[3].w; - accumulator[1][15].w += input_frag[3].w*filter_frag[3].w; - } - -extern "C" -{ - -__device__ __forceinline__ void transform_output_tile(float *pOutputs, float2 *C_tile, float2 *At, - int round, int c_tensor, int c_glb_offset, int i1, int i2, - unsigned short mask1, unsigned short mask2, int out_w) -{ - - c_tensor += (((round)/2)*32 + ((round)%2)*2)*c_glb_offset; - int x, x1; - - #pragma unroll - for(int j=0; j<4; j++){ - - At[j].x = C_tile[j].x + C_tile[4+j].x + C_tile[8+j].x; - At[j].y = C_tile[j].y + C_tile[4+j].y + C_tile[8+j].y; - - At[4+j].x = C_tile[4+j].x - C_tile[8+j].x - C_tile[12+j].x; - At[4+j].y = C_tile[4+j].y - C_tile[8+j].y - C_tile[12+j].y; - - } - - #pragma unroll - for(int i=0; i<2; i++){ - x = i*4; - x1 = i*((out_w-(out_w%2)) + (out_w%2)/2); - - if(mask1&(1<<(i*2))){ - pOutputs[x1 + c_tensor + i1] = At[x].x + At[x+1].x + At[x+2].x; - } - if(mask2&(1<<(i*2))){ - pOutputs[x1 + c_tensor + i2] = At[x].y + At[x+1].y + At[x+2].y; - } - if(mask1&(1<<(i*2+1))){ - pOutputs[x1 + c_tensor + i1 + 1] = At[x+1].x - At[x+2].x - At[x+3].x; - } - if(mask2&(1<<(i*2+1))){ - pOutputs[x1 + c_tensor + i2 + 1] = At[x+1].y - At[x+2].y - At[x+3].y; - } - } -} - -__device__ __forceinline__ unsigned short get_mask(int idd, int tiles_dim_w, int tiles_dim_h, - int tw, int th, int out_w, int out_h){ - - unsigned short mask = 0x000F; - // if((blockIdx.y/tiles_dim)==(tiles_dim-1) && out_w%2) mask&=0x0003; // pad bottom row - // if(!((blockIdx.y+1)%tiles_dim) && out_w%2) mask&=0X0005; // pad right col - // if(blockIdx.y==gridDim.y-1 && (idd / tw) == th-1 && out_h%2) mask&=0x0003; // pad bottom row - // if(blockIdx.x==gridDim.x-1 && (idd % tw) == tw-1 && out_w%2) mask&=0X0005; // pad right col - if(tiles_dim_w % tw == 0 && tiles_dim_h % th == 0){ - if(blockIdx.y==gridDim.y-1 && (idd / tw) == th-1 && out_h%2) mask&=0x0003; // pad bottom row - if(blockIdx.x==gridDim.x-1 && (idd % tw) == tw-1 && out_w%2) mask&=0X0005; // pad right col - }else if(tiles_dim_w % tw == 0){ - int k = out_h % TH; - int k1 = k % 2 ? (k+1)/2 : k/2; // there could be 4*k1 tiles - if(blockIdx.y==gridDim.y-1 && (idd / tw) == k1-1 && k%2) mask&=0x0003; // pad bottom row - if(blockIdx.y==gridDim.y-1 && (idd / tw) > k1-1) mask &= 0x0; //pad all zeros since this tile does not exist - }else if(tiles_dim_h % th == 0){ - int k = out_w % TW; - int k1 = k % 2 ? (k+1)/2 : k/2; // there could be 4*k1 tiles - if(blockIdx.x==gridDim.x-1 && (idd % tw) == k1-1 && k%2) mask&=0X0005; // pad right col - if(blockIdx.x==gridDim.x-1 && (idd % tw) > k1-1) mask&=0X0; // pad all zeroes - }else{ - int kh = out_h % TH; - int kw = out_w % TW; - int kh1 = kh % 2 ? (kh+1)/2 : kh/2; // there could be kh1*kw1 tiles - int kw1 = kw % 2 ? 
(kw+1)/2 : kw/2; - if(blockIdx.y==gridDim.y-1 && (idd / tw) == kh1-1 && kh%2) mask&=0x0003; // pad bottom row - if(blockIdx.x==gridDim.x-1 && (idd % tw) == kw1-1 && kw%2) mask&=0X0005; // pad right col - if(blockIdx.y==gridDim.y-1 && (idd / tw) > kh1-1) mask &= 0x0; //pad all zeros since this tile does not exist - if(blockIdx.x==gridDim.x-1 && (idd % tw) > kw1-1) mask &= 0X0; // pad all zeroes - } - return mask; -} - -__device__ __forceinline__ void store_output_tile(float4 acumm_smem[][16], float *shared_mem, float *C, -int out_h, int out_w, int tiles_dim_w, int tiles_dim_h, int tw, int th, -float4 *input_frag_mem, float4* filter_frag_mem){ - - float2 *output_smem = (float2 *) shared_mem; - float2 *accumulator = (float2 *) acumm_smem; - float2 *C_out = (float2*)C; - - float2 *C_tile = (float2*) input_frag_mem; - float2 *At = (float2*) filter_frag_mem; - - int idd1 = tileid[0][threadIdx.x]; - int id1 = (idd1 % tw) * 2 + (idd1 / tw) * out_w * 2; - int idd2 = tileid[1][threadIdx.x]; - int id2 = (idd2 % tw) * 2 + (idd2 / tw) * out_w * 2; - - // unsigned short mask1 = 0x000F; - unsigned short mask1 = get_mask(idd1, tiles_dim_w, tiles_dim_h, tw, th, out_w, out_h); - unsigned short mask2 = get_mask(idd2, tiles_dim_w, tiles_dim_h, tw, th, out_w, out_h); - - // output transpose step - int t=0; - int acumm1, acumm2; - // For transposing - //acumm1 = access_s_out[Inx]; //* 4 - acumm1 = ((threadIdx.x%8)/2)*34 + threadIdx.x%2 + (threadIdx.x/16)*2 + ((threadIdx.x/8)%2)*8; - acumm2 = acumm1+4; - - int acumm4 = BN_p*16 ; //*4 - int idx = threadIdx.y * BN_p; - int idx2 = idx + BN_p*8; //(BN_p*2 *8)/2 - - // For transformating - int offset = BN_p *2; //*2/2 - int init = ( (threadIdx.y/4)*BN_p*16 + (threadIdx.y%4)*(32+2) ) *2 + threadIdx.x; - - int c_glb_offset = out_h*out_w; - // int c_tensor = blockIdx.z*c_glb_offset*BK + (blockIdx.y%tiles_dim)*2 + (blockIdx.y/tiles_dim)*out_w*2 + - // blockIdx.x*BN + (threadIdx.x%16)*2+ - // ((threadIdx.x/16)*16 + (threadIdx.y%4)*4 + threadIdx.y/4)*c_glb_offset; - - int tx = TW, ty = TH; - // int c_tile = blockIdx.x * tx + blockIdx.y * in_w * ty; - // int c_tensor = c_tile + (threadIdx.x % tw) * 2 + (threadIdx.x / tw) * in_w * 2 + - // threadIdx.y*(in_h*in_w) - (in_w+1); - - int c_tensor = blockIdx.z*c_glb_offset*BK + blockIdx.x * tx + blockIdx.y * out_w * ty + - // (threadIdx.x % tw) * 2 + (threadIdx.x / tw) * out_w * 2 + - ((threadIdx.x/16)*16 + (threadIdx.y%4)*4 + threadIdx.y/4)*c_glb_offset; - - // c_tensor/=2; - - - int target = 16; - - #pragma unroll - for(int round=0; round<4; round++){ - - *( (float2*) (output_smem + idx + acumm1) ) = *(accumulator+t); - *( (float2*) (output_smem + idx + acumm1 + 16) ) = *(accumulator+t+1); // float 4, t - *( (float2*) (output_smem + idx + acumm2) ) = *(accumulator+t+2); - *( (float2*) (output_smem + idx + acumm2 + 16) ) = *(accumulator+t+3); // float 4, t+1 - - - *( (float2*) (output_smem + idx2 + acumm1) ) = *(accumulator+t+32); - *( (float2*) (output_smem + idx2 + acumm1 + 16) ) = *(accumulator+t+33); // float 4, t+16 - *( (float2*) (output_smem + idx2 + acumm2) ) = *(accumulator+t+34); - *( (float2*) (output_smem + idx2 + acumm2 + 16) ) = *(accumulator+t+35); // float 4, t+17 - - // the above 8 float2 will be consumed by theadIdx.y = [0,1,2,3] - - // the following 8 float2 will be consumed by theadIdx.y = [4,5,6,7] - - *( (float2*) (output_smem + idx + acumm4 + acumm1) ) = *(accumulator+t+4); - *( (float2*) (output_smem + idx + acumm4 + acumm1 + 16) ) = *(accumulator+t+5); // float 4, t+2 - *( (float2*) (output_smem + idx + 
acumm4 + acumm2) ) = *(accumulator+t+6); - *( (float2*) (output_smem + idx + acumm4 + acumm2 + 16) ) = *(accumulator+t+7); // float 4, t+3 - - *( (float2*) (output_smem + idx2 + acumm4 + acumm1) ) = *(accumulator+t+36); - *( (float2*) (output_smem + idx2 + acumm4 + acumm1 + 16) ) = *(accumulator+t+37); // float 4, t+18 - *( (float2*) (output_smem + idx2 + acumm4 + acumm2) ) = *(accumulator+t+38); - *( (float2*) (output_smem + idx2 + acumm4 + acumm2 + 16) ) = *(accumulator+t+39); // float 4, t+19 - - - - t+=8; - - __syncthreads(); - - - - // for output transformation, the role of threadIdx.y changes again: - // in the main loop, different threadIdx.y deal with different element of the 4x4 tile - // here, they are for 4 different groups of lane ids from optSTS64 layout - // for init (and init+32), we need to identify its tile number (0-31) within the supertile - // first, from init, find out from which threadIdx.x it comes. - - // now we got l, which is the land id which computed accumulated sum for the tile element - // each lane id (or threadIdx.x) computed 8 tiles which are distributed into 4 locations spreading - // over the smem. We need to find which of the 8 the current tile is. - // use tileid table to figure out - // int id1 = tileid[0][l]; - - - // for 2nd tile - - - for(int i=0; i<16; i++){ - C_tile[i].x = shared_mem[i*offset + init]; - C_tile[i].y = shared_mem[i*offset + init + 32]; - - } - - - // transform output tiles - transform_output_tile(C, C_tile, At, round, c_tensor, c_glb_offset, id1, id2, mask1, mask2, out_w); - __syncthreads(); - } -} - - -// Set of functions per row in Gw product -__device__ float f_row1(float *Gw, int j){ - return Gw[j]; - } - __device__ float f_row2(float *Gw, int j){ - return 0.5*(Gw[j] + Gw[6+j] + Gw[3+j]); - } - __device__ float f_row3(float *Gw, int j){ - return 0.5*(Gw[j] + Gw[6+j] - Gw[3+j]); - } - __device__ float f_row4(float *Gw, int j){ - return Gw[6+j]; - } - // Set of functions per column in GwGt product - __device__ float f_col1(float *Gw, int j){ - return Gw[j]; - } - __device__ float f_col2(float *Gw, int j){ - return 0.5*(Gw[j] + Gw[j+2] + Gw[j+1]); - } - __device__ float f_col3(float *Gw, int j){ - return 0.5*(Gw[j] + Gw[j+2] - Gw[j+1]); - } - __device__ float f_col4(float *Gw, int j){ - return Gw[j+2]; - } - - typedef float(*pointFunction_t)(float *, int); - - __global__ void FX(float *pInputs, float *pOutputs, int filt_k, - int filt_c, int filt_h, int filt_w, int alpha){ - int Inx = threadIdx.x, Iny = threadIdx.y; - int TileX = blockIdx.x, TileY = blockIdx.y; - - int c_glb_offset = filt_k*filt_h*filt_w; - int c_kernel = TileY*BC*c_glb_offset + TileX*BK + Iny*c_glb_offset + Inx; - int c_glb_offset_s = filt_k*4*4; - int c_kernel_s = TileY*BC*c_glb_offset_s + TileX*BK + Iny*c_glb_offset_s + Inx; - - float Gw[21]; //9+12. In registers - float *Gw_buffer = Gw+9; - - pointFunction_t func1[4] = {f_row1, f_row2, f_row3, f_row4}; - pointFunction_t func2[4] = {f_col1, f_col2, f_col3, f_col4}; - - for(int bk=0; bk= BC) return; - - // each thread in row 0 puts its first element of 1st filter tile(loaded by the thread) in smem - // taking 32 slots - // then puts its first element of 2nd filter tile immediately after, taking another 32 slots - // then followed by threads in row 1, 2.. until 7 - - // Note the next element is BK*BC (8*64) slots away, then another BK*BC .... 
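
For reference, the FX kernel and the f_row*/f_col* helpers above implement the filter transform of the F(2x2, 3x3) Winograd algorithm used throughout this series (m = 2, r = 3, 4x4 tiles): each 3x3 kernel g is expanded into a 4x4 tile G*g*G^T, where G has rows (1,0,0), (1/2,1/2,1/2), (1/2,-1/2,1/2), (0,0,1); f_row2/f_row3 apply the two averaging rows of G and f_col2/f_col3 do the same along the columns. A minimal host-side sketch of that same math follows (the function name is illustrative only, not part of this patch):

    // G * g * G^T for one 3x3 kernel, producing the 4x4 Winograd-domain filter tile
    static void wino_f2x2_3x3_filter_transform(const float g[3][3], float out[4][4]) {
        float t[4][3]; // t = G * g
        for (int j = 0; j < 3; j++) {
            t[0][j] = g[0][j];
            t[1][j] = 0.5f*(g[0][j] + g[1][j] + g[2][j]);
            t[2][j] = 0.5f*(g[0][j] - g[1][j] + g[2][j]);
            t[3][j] = g[2][j];
        }
        for (int i = 0; i < 4; i++) { // out = t * G^T
            out[i][0] = t[i][0];
            out[i][1] = 0.5f*(t[i][0] + t[i][1] + t[i][2]);
            out[i][2] = 0.5f*(t[i][0] - t[i][1] + t[i][2]);
            out[i][3] = t[i][2];
        }
    }

On the output side, transform_output_tile applies the matching reduction with A^T = [[1,1,1,0],[0,1,-1,-1]] (the row-sum and alternating-sum lines in its loop) to turn each 4x4 accumulator tile back into a 2x2 output patch.
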
- // for every 64 values, the first 32 belongs to filter tile 1, the next 32 for filter tile 2 - - for(int k=0; k<2; k++){ // prefetch 2 filter tiles/thread - for(int i=0; i<4; i++){ - #pragma unroll - for(int j=0; j<4; j++){ - pOutputs[c_tensor_s + i*c_offset_s*4 + j*c_offset_s] = tiles[k*16 + i*4 + j]; - } - } - // 2nd tile right behind the 1st? - c_tensor_s += BN; // BN has nothing to do with input tiles - } - -} - -__device__ __forceinline__ void prefetch_filter_tile(float *pInputs, float *tiles, int filt_k){ - - int c_tensor = blockIdx.z*BK + (threadIdx.y*filt_k<<4) + threadIdx.x; // Iny*filt_k*4*4 - // each threadIdx.y corresponds to one channel; there are 8 different threadIdx.y so 8 channels - - //each thread (32 threads in x direction) loads 2 kernel tiles (32 in K direction apart) - // save the two tiles in a float[32] register, float[16] for each - - int acumm; - #pragma unroll - for(int i=0; i<4; i++){ - acumm = (i*filt_k<<2); - #pragma unroll - for(int j=0; j<4; j++){ - tiles[(i<<2) + j] = pInputs[acumm + j*filt_k + c_tensor]; - tiles[16 + (i<<2) + j] = pInputs[acumm + j*filt_k + c_tensor+BN]; - } - } -} - -__device__ __forceinline__ void prefetch_input_tile(float *pInputs, float *tile, int in_h, - int in_w, int tw, int th, unsigned short mask){ - - // load one input tile - int tx = TW, ty = TH; - int c_tile = blockIdx.x * tx + blockIdx.y * in_w * ty; - int c_tensor = c_tile + (threadIdx.x % tw) * 2 + (threadIdx.x / tw) * in_w * 2 + - threadIdx.y*(in_h*in_w) - (in_w+1); - - int acumm,x; - - - if(mask==0xFFFF){ - #pragma unroll - for(int i=0; i<4; i++){ - acumm = i*in_w; - #pragma unroll - for(int j=0; j<4; j++){ - tile[(i<<2) + j] = pInputs[acumm + j + c_tensor]; - } - } - - } else { - for(int i=0; i<4; i++){ - acumm = i*in_w; - #pragma unroll - for(int j=0; j<4; j++){ - x = (i<<2) + j; - tile[x] = 0.f; - if(mask&(1< k1-1) m &= 0x0; //pad all zeros since this tile does not exist - }else if(tiles_dim_h % Y == 0){ - int k = in_w % TW; - int k1 = k % 2 ? (k+1)/2 : k/2; // there could be 8*k1 tiles - if(blockIdx.y==gridDim.y-1 && threadIdx.x / X == Y-1) m &= (!(in_h%2))?(0x0FFF):(0x00FF); //pad bottom row or bottom 2 rows - if(blockIdx.x==gridDim.x-1 && threadIdx.x % X == k1-1) m &= (!(k%2))?(0x7777):(0x3333); // pad right col or right 2 cols - if(blockIdx.x==gridDim.x-1 && threadIdx.x % X > k1-1) m &= 0x0; //pad all zeros since this tile does not exist - }else{ - int kh = in_h % TH; - int kw = in_w % TW; - int kh1 = kh % 2 ? (kh+1)/2 : kh/2; // there could be kh1*kw1 tiles - int kw1 = kw % 2 ? (kw+1)/2 : kw/2; - if(blockIdx.y==gridDim.y-1 && threadIdx.x / X == kh1-1) m &= (!(kh%2))?(0x0FFF):(0x00FF); //pad bottom row or bottom 2 rows - if(blockIdx.y==gridDim.y-1 && threadIdx.x / X > kh1-1) m &= 0x0; //pad all zeros since this tile does not exist - if(blockIdx.x==gridDim.x-1 && threadIdx.x % X == kw1-1) m &= (!(kw%2))?(0x7777):(0x3333); // pad right col or right 2 cols - if(blockIdx.x==gridDim.x-1 && threadIdx.x % X > kw1-1) m &= 0x0; //pad all zeros since this tile does not exist - } - if(blockIdx.x==0 && (threadIdx.x % X) == 0) m &=0xeeee; // pad left col - - float img_tile[16]; // Prefetch input from GMEM - float filter_tile[32]; // Prefetch filter from GMEM - - float4 input_frag_mem[8]; //2*2(2*8/4) Data to do Outer Product + prefetch f. SMEM (double_buffer) - float4 filter_frag_mem[8]; //2*2 Data to do Outer Product + prefetch f. 
SMEM (double_buffer) - float4 accumulator[2][16] = {0.0f}; // Accumulators - - float4 *A_frag; // Input data pointer - int frag_offset = 2 * (BN*BC); // (2=8/4) SMEM input read offset - - float4 *B_frag; // Filter data pointer - int f_frag_offset = 2 * (BC*BK); // (2=8/4 with 4 being float4) SMEM filter read offset - - - float4 *input_frag = (float4*) input_frag_mem; - float4 *filter_frag = (float4*) filter_frag_mem; - - float4 *swap_filter; - float4 *swap_input; - - prefetch_input_tile(A, img_tile, in_h, in_w, X, Y, m); - prefetch_filter_tile(B, filter_tile, filt_k); - - float4 *input_frag_buffer = (float4*) (input_frag+4); - float4 *filter_frag_buffer = (float4*) (filter_frag+4); - - // Mainloop - iterates over the entire K dimension - not unrolled - for(int iter=0; iter>>(w, Ww, filt_k, filt_c, filt_h, filt_w, alpha); - - // each thread block will load 32 tiles (4x4) from the single image input - // we let X*Y = 32 and arbitraraly pick X = 4 and Y = 8 - Winograd_kernel<<>>(k, Ww, C, - tiles_dim_w, tiles_dim_h, in_c, in_h, in_w, tile_size, X, Y, filt_k, filt_c, out_c, tile_2d_s, out_h, out_w); - - return cudaGetLastError(); -} - -} void ggml_cuda_op_winograd_stage0(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/src/ggml.c b/src/ggml.c index 65d8c79a0..389878991 100644 --- a/src/ggml.c +++ b/src/ggml.c @@ -3026,7 +3026,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "OPT_STEP_ADAMW", }; -static_assert(GGML_OP_COUNT == 80, "GGML_OP_COUNT != 80"); +static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -7184,7 +7184,7 @@ struct ggml_tensor * ggml_winograd_stage0( is_node = true; } - struct ggml_tensor * result = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 16, a->ne[2], a->ne[3], 1); + struct ggml_tensor * result = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 4, 4, a->ne[2], a->ne[3]); result->op = GGML_OP_WINOGRAD_STAGE0; result->grad = is_node ? 
ggml_dup_tensor(ctx, result) : NULL; @@ -15195,6 +15195,23 @@ static void ggml_compute_forward_conv_transpose_1d( } } + +static void ggml_compute_forward_winograd_stage0( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + GGML_ASSERT(false && " CPU backend not implemented!"); + return; +} + +static void ggml_compute_forward_winograd_stage1( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + GGML_ASSERT(false && " CPU backend not implemented!"); + return; +} + // ggml_compute_forward_im2col_f32 // src0: kernel [OC, IC, KH, KW] // src1: image [N, IC, IH, IW] @@ -17891,6 +17908,14 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_conv_transpose_1d(params, tensor); } break; + case GGML_OP_WINOGRAD_STAGE0: + { + ggml_compute_forward_winograd_stage0(params, tensor); + } break; + case GGML_OP_WINOGRAD_STAGE1: + { + ggml_compute_forward_winograd_stage1(params, tensor); + } break; case GGML_OP_IM2COL: { ggml_compute_forward_im2col(params, tensor); @@ -18964,6 +18989,14 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor { GGML_ABORT("fatal error"); // TODO: not implemented } + case GGML_OP_WINOGRAD_STAGE0: + { + GGML_ABORT("fatal error"); // TODO: not implemented + } + case GGML_OP_WINOGRAD_STAGE1: + { + GGML_ABORT("fatal error"); // TODO: not implemented + } case GGML_OP_POOL_1D: { GGML_ABORT("fatal error"); // TODO: not implemented diff --git a/tests/test-conv2d-winograd.cpp b/tests/test-conv2d-winograd.cpp new file mode 100644 index 000000000..371277399 --- /dev/null +++ b/tests/test-conv2d-winograd.cpp @@ -0,0 +1,395 @@ +#include "ggml.h" +#include "ggml-alloc.h" +#include "ggml-backend.h" + +#ifdef GGML_USE_CUDA +#include "ggml-cuda.h" +#endif + +#ifdef GGML_USE_METAL +#include "ggml-metal.h" +#endif + +#include +#include +#include +#include +#include +#include +#include +#include + +static void ggml_log_callback_default(ggml_log_level level, const char * text, void * user_data) { + (void) level; + (void) user_data; + fputs(text, stderr); + fflush(stderr); +} + +struct test_model { + struct ggml_tensor * a; + struct ggml_tensor * b; + ggml_backend_t backend = NULL; + ggml_backend_buffer_t buffer; + struct ggml_context * ctx; +}; + +void load_model(test_model & model, bool use_gpu = false) { + // create data + int KW = 3, KH = 3, IC = 10, OC = 10; + int IW = 8, IH = 6, N = 1; + + // Initialize adata + std::vector adata(KW * KH * IC * OC); + for (int i = 0; i < KW * KH * IC * OC; i++) { + adata[i] = 2.5f; + } + + // Convert adata to fp16 format + std::vector hadata(KW * KH * IC * OC); + ggml_fp32_to_fp16_row(adata.data(), hadata.data(), KW * KH * IC * OC); + + // Initialize bdata + std::vector bdata(IW * IH * IC * N); + for (int i = 0; i < IW * IH * IC * N; i++) { + bdata[i] = 1.5f; + } + + size_t buffer_size = 0; + { + buffer_size += KW * KH * IC * OC * ggml_type_size(GGML_TYPE_F16); // tensor a + buffer_size += IW * IH * IC * N * ggml_type_size(GGML_TYPE_F32); // tensor b + buffer_size += 1024; // overhead + } + + printf("%s: ggml tensor size = %d bytes\n", __func__, (int) sizeof(ggml_tensor)); + printf("%s: backend buffer size = %0.2f MB\n", __func__, (buffer_size/ 1024.f/ 1024.f)); + + int num_tensors = 2; + struct ggml_init_params params { + /*.mem_size =*/ ggml_tensor_overhead() * num_tensors, + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + + // initialize the backend +#ifdef GGML_USE_CUDA + if (use_gpu) { + fprintf(stderr, "%s: using CUDA 
backend\n", __func__); + model.backend = ggml_backend_cuda_init(0); + if (!model.backend) { + fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__); + } + } +#endif + +#ifdef GGML_USE_METAL + if (use_gpu) { + fprintf(stderr, "%s: using Metal backend\n", __func__); + ggml_backend_metal_log_set_callback(ggml_log_callback_default, nullptr); + model.backend = ggml_backend_metal_init(); + if (!model.backend) { + fprintf(stderr, "%s: ggml_backend_metal_init() failed\n", __func__); + } + } +#endif + + if(!model.backend) { + // fallback to CPU backend + model.backend = ggml_backend_cpu_init(); + } + + model.buffer = ggml_backend_alloc_buffer(model.backend, buffer_size); + + // create context + model.ctx = ggml_init(params); + + // create tensors + model.a = ggml_new_tensor_4d(model.ctx, GGML_TYPE_F16, KW, KH, IC, OC); + model.b = ggml_new_tensor_4d(model.ctx, GGML_TYPE_F32, IW, IH, IC, N); + + // create a allocator + struct ggml_tallocr alloc = ggml_tallocr_new(model.buffer); + + // alloc memory + ggml_tallocr_alloc(&alloc, model.a); + + // load data to buffer + if(ggml_backend_is_cpu(model.backend)) { + memcpy(model.a->data, hadata.data(), ggml_nbytes(model.a)); + } else { + ggml_backend_tensor_set(model.a, hadata.data(), 0, ggml_nbytes(model.a)); + } + + // alloc memory + ggml_tallocr_alloc(&alloc, model.b); + + if(ggml_backend_is_cpu(model.backend) +#ifdef GGML_USE_METAL + || ggml_backend_is_metal(model.backend) +#endif + ) { + memcpy(model.b->data, bdata.data(), ggml_nbytes(model.b)); + } else { + ggml_backend_tensor_set(model.b, bdata.data(), 0, ggml_nbytes(model.b)); + } +} + +struct ggml_cgraph * build_graph(const test_model& model) { + static size_t buf_size = ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead(); + static std::vector buf(buf_size); + + struct ggml_init_params params0 = { + /*.mem_size =*/ buf_size, + /*.mem_buffer =*/ buf.data(), + /*.no_alloc =*/ true, // the tensors will be allocated later by ggml_gallocr_alloc_graph() + }; + + // create a temporally context to build the graph + struct ggml_context * ctx0 = ggml_init(params0); + + struct ggml_cgraph * gf = ggml_new_graph(ctx0); + + int s0 = 1; + int s1 = 1; + int p0 = 1; + int p1 = 1; + int d0 = 1; + int d1 = 1; + + // split conv2d in fundamental methods for test unit + struct ggml_tensor* im2col_0 = ggml_im2col(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1, true, GGML_TYPE_F16); + ggml_set_name(im2col_0, "im2col_res"); + ggml_build_forward_expand(gf, im2col_0); + + // recalculate for avoid fragmentation + struct ggml_tensor* conv2d_res = ggml_conv_2d(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1); + ggml_set_name(conv2d_res, "conv2d_res"); + ggml_build_forward_expand(gf, conv2d_res); + + ggml_free(ctx0); + return gf; +} + +struct ggml_cgraph * compute_graph(const test_model & model, ggml_gallocr_t allocr) { + struct ggml_cgraph * gf = build_graph(model); + + // allocate tensors + ggml_gallocr_alloc_graph(allocr, gf); + int n_threads = 1; + + if (ggml_backend_is_cpu(model.backend)) { + ggml_backend_cpu_set_n_threads(model.backend, n_threads); + } + +#ifdef GGML_USE_METAL + if (ggml_backend_is_metal(model.backend)) { + ggml_backend_metal_set_n_cb(model.backend, n_threads); + } +#endif + + ggml_backend_graph_compute(model.backend, gf); + + //ggml_graph_print(gf); + + return gf; +} + +int main(void) +{ + ggml_time_init(); + + test_model model; + load_model(model, true); + + ggml_gallocr_t allocr = NULL; + + { + allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(model.backend)); + 
+ //create the worst case graph for memory usage estimation + struct ggml_cgraph * gf = build_graph(model); + + // compute the required memory + ggml_gallocr_reserve(allocr, gf); + size_t mem_size = ggml_gallocr_get_buffer_size(allocr, 0); + fprintf(stderr, "%s: compute buffer size: %.2f MB\n", __func__, mem_size/1024.0f/1024.0f); + } + + struct ggml_cgraph * gf_res = compute_graph(model, allocr); + + struct ggml_tensor * im2col_res = NULL; + struct ggml_tensor * conv2d_res = NULL; + + for(int i = 0; i < ggml_graph_n_nodes(gf_res); ++i) { + if(strcmp(ggml_get_name(ggml_graph_node(gf_res, i)), "im2col_res") == 0) { + im2col_res = ggml_graph_node(gf_res, i); + } else if(strcmp(ggml_get_name(ggml_graph_node(gf_res, i)), "conv2d_res") == 0) { + conv2d_res = ggml_graph_node(gf_res, i); + } + } + + std::vector im2col_data(ggml_nelements(im2col_res)); + std::vector conv2d_data(ggml_nelements(conv2d_res)); + + ggml_backend_tensor_get(im2col_res, im2col_data.data(), 0, ggml_nbytes(im2col_res)); + ggml_backend_tensor_get(conv2d_res, conv2d_data.data(), 0, ggml_nbytes(conv2d_res)); + + const int n_conv2d_test = 480; + const int n_im2col_test = 4320; + + float expected_conv2d [n_conv2d_test] = { + 150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f, + 150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f, + 150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f, + 150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f, + 150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f, + 150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 
337.50f, 225.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f, + 150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f, + 150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f, + 150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f, + 150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, + 150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f }; + + uint16_t expected_im2col[n_conv2d_test] = { + 0, 0, 0, 0, 15872, 15872, 0, 15872, + 15872, 0, 0, 0, 0, 15872, 15872, 0, + 15872, 15872, 0, 0, 0, 0, 15872, 15872, + 0, 15872, 15872, 0, 0, 0, 0, 15872, + 15872, 0, 15872, 15872, 0, 0, 0, 0, + 15872, 15872, 0, 15872, 15872, 0, 0, 0, + 0, 15872, 15872, 0, 15872, 15872, 0, 0, + 0, 0, 15872, 15872, 0, 15872, 15872, 0, + 0, 0, 0, 15872, 15872, 0, 15872, 15872, + 0, 0, 0, 0, 15872, 15872, 0, 15872, + 15872, 0, 0, 0, 0, 15872, 15872, 0, + 15872, 15872, 0, 0, 0, 15872, 15872, 15872, + 15872, 15872, 15872, 0, 0, 0, 15872, 15872, + 15872, 15872, 15872, 15872, 0, 0, 0, 15872, + 15872, 15872, 15872, 15872, 15872, 0, 0, 0, + 15872, 15872, 15872, 15872, 15872, 15872, 0, 0, + 0, 15872, 15872, 15872, 15872, 15872, 15872, 0, + 0, 0, 15872, 15872, 15872, 15872, 15872, 15872, + 0, 0, 0, 15872, 15872, 15872, 15872, 15872, + 15872, 0, 0, 0, 15872, 15872, 15872, 15872, + 15872, 15872, 0, 0, 0, 15872, 15872, 15872, + 15872, 15872, 15872, 0, 0, 0, 15872, 15872, + 15872, 15872, 15872, 15872, 0, 0, 0, 15872, + 15872, 15872, 15872, 15872, 15872, 0, 0, 0, + 15872, 15872, 15872, 15872, 15872, 15872, 0, 0, + 0, 15872, 15872, 15872, 15872, 15872, 15872, 0, + 0, 0, 15872, 15872, 15872, 15872, 15872, 15872, + 0, 0, 0, 15872, 15872, 15872, 15872, 15872, + 15872, 0, 0, 0, 15872, 15872, 15872, 15872, + 15872, 15872, 0, 0, 0, 15872, 15872, 15872, + 15872, 15872, 15872, 0, 0, 0, 15872, 15872, + 15872, 15872, 15872, 15872, 0, 
0, 0, 15872, + 15872, 15872, 15872, 15872, 15872, 0, 0, 0, + 15872, 15872, 15872, 15872, 15872, 15872, 0, 0, + 0, 15872, 15872, 15872, 15872, 15872, 15872, 0, + 0, 0, 15872, 15872, 15872, 15872, 15872, 15872, + 0, 0, 0, 15872, 15872, 15872, 15872, 15872, + 15872, 0, 0, 0, 15872, 15872, 15872, 15872, + 15872, 15872, 0, 0, 0, 15872, 15872, 15872, + 15872, 15872, 15872, 0, 0, 0, 15872, 15872, + 15872, 15872, 15872, 15872, 0, 0, 0, 15872, + 15872, 15872, 15872, 15872, 15872, 0, 0, 0, + 15872, 15872, 15872, 15872, 15872, 15872, 0, 0, + 0, 15872, 15872, 15872, 15872, 15872, 15872, 0, + 0, 0, 15872, 15872, 15872, 15872, 15872, 15872, + 0, 0, 0, 15872, 15872, 15872, 15872, 15872, + 15872, 0, 0, 0, 15872, 15872, 15872, 15872, + 15872, 15872, 0, 0, 0, 15872, 15872, 15872, + 15872, 15872, 15872, 0, 0, 0, 15872, 15872, + 15872, 15872, 15872, 15872, 0, 0, 0, 15872, + 15872, 15872, 15872, 15872, 15872, 0, 0, 0, + 15872, 15872, 15872, 15872, 15872, 15872, 0, 0, + 0, 15872, 15872, 15872, 15872, 15872, 15872, 0, + 0, 0, 15872, 15872, 15872, 15872, 15872, 15872, + 0, 0, 0, 15872, 15872, 15872, 15872, 15872, + 15872, 0, 0, 0, 15872, 15872, 15872, 15872, + 15872, 15872, 0, 0, 0, 15872, 15872, 15872, + 15872, 15872, 15872, 0, 0, 0, 15872, 15872, + 15872, 15872, 15872, 15872, 0, 0, 0, 15872, + 15872, 15872, 15872, 15872, 15872, 0, 0, 0 + }; + + printf("\nPerforming test:\n"); + + bool passed = true; + for(int i = 0; i < n_conv2d_test; i++) { + if( + im2col_data[i] != expected_im2col[i]) { + passed = false; + break; + } + } + + printf("ggml_im2col (%d): %s\n", (int) ggml_nelements(im2col_res), passed && (ggml_nelements(im2col_res) == n_im2col_test) ? "\033[32mPASSED\033[0m" : "\033[31mFAILED\033[0m"); + + passed = true; + for(int i = 0; i < n_conv2d_test; i++) { + if(conv2d_data[i] != expected_conv2d[i]) { + passed = false; + break; + } + } + + printf("ggml_conv2d (%d): %s\n", (int) ggml_nelements(conv2d_res), passed && (ggml_nelements(conv2d_res) == n_conv2d_test) ? 
"\033[32mPASSED\033[0m" : "\033[31mFAILED\033[0m"); + + ggml_free(model.ctx); + + ggml_backend_buffer_free(model.buffer); + ggml_backend_free(model.backend); + ggml_gallocr_free(allocr); + return 0; +} From 3d804665ae978d8a90fceac31f160c7c96a1789a Mon Sep 17 00:00:00 2001 From: bssrdf Date: Fri, 27 Sep 2024 13:43:57 -0400 Subject: [PATCH 05/23] test now passed; for some reason, ggml_conv_2d didn't output correct results --- src/CMakeLists.txt | 4 +- src/ggml-cuda/conv-winograd.cu | 22 ++-- src/ggml-cuda/conv-winograd.cuh | 2 +- src/ggml.c | 12 +-- tests/CMakeLists.txt | 9 ++ tests/test-conv2d-winograd.cpp | 185 +++++--------------------------- 6 files changed, 60 insertions(+), 174 deletions(-) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index cbc349500..83c6cb2ce 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -285,9 +285,9 @@ if (GGML_CUDA) # 61 == integer CUDA intrinsics # 70 == compute capability at which unrolling a loop in mul_mat_q kernels is faster if (GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16) - set(CMAKE_CUDA_ARCHITECTURES "60;61;70;75") + set(CMAKE_CUDA_ARCHITECTURES "60;61;70;75;86") else() - set(CMAKE_CUDA_ARCHITECTURES "52;61;70;75") + set(CMAKE_CUDA_ARCHITECTURES "52;61;70;75;86") #set(CMAKE_CUDA_ARCHITECTURES "OFF") # use this to compile much faster, but only F16 models work endif() endif() diff --git a/src/ggml-cuda/conv-winograd.cu b/src/ggml-cuda/conv-winograd.cu index eaa2b9c42..60587db76 100644 --- a/src/ggml-cuda/conv-winograd.cu +++ b/src/ggml-cuda/conv-winograd.cu @@ -386,6 +386,8 @@ __device__ float f_row1(float *Gw, int j){ __global__ void FX(const float *pInputs, float *pOutputs, int filt_k, int filt_c, int filt_h, int filt_w){ + + // assumes CHWK layout int Inx = threadIdx.x, Iny = threadIdx.y; int TileX = blockIdx.x, TileY = blockIdx.y; @@ -725,14 +727,14 @@ static void conv_winograd_stage0_f32_f32_cuda( cudaStream_t stream) { - int64_t filt_k = src0_ne3; - int64_t filt_c = src0_ne2; + int64_t filt_k = src0_ne0; + int64_t filt_c = src0_ne3; - FX<<>>(src0, dst, filt_k, filt_c, src0_ne1, src0_ne0); + FX<<>>(src0, dst, filt_k, filt_c, src0_ne2, src0_ne1); } -static void conv_winograd_stage1_f16_f32_cuda(int tiles_dim_w, int tiles_dim_h, int X, int Y, +static void conv_winograd_stage1_f32_f32_cuda(int tiles_dim_w, int tiles_dim_h, int X, int Y, int tile_size, int tile_2d_s, const int src0_ne0, const int src0_ne1, const int src0_ne2, const int src0_ne3, const int src1_ne0, const int src1_ne1, const int src1_ne2, const int src1_ne3, @@ -740,16 +742,20 @@ static void conv_winograd_stage1_f16_f32_cuda(int tiles_dim_w, int tiles_dim_h, const float * src0, const float * src1, float * dst, cudaStream_t stream) { - int64_t filt_k = src0_ne3; + int64_t filt_k = src0_ne0; int64_t in_c = src1_ne2; int64_t in_h = src1_ne1; int64_t in_w = src1_ne0; - int64_t filt_c = src1_ne0; + int64_t filt_c = src0_ne3; int64_t out_c = filt_k; int64_t out_h = in_h; int64_t out_w = in_w; int smem_size = (16*BN*BC + 16*BC*BK)*4; + printf("A %d, %d\n", filt_k, filt_c); + printf("B %d, %d, %d \n", in_c, in_h, in_w); + printf("C %d, %d, %d \n", out_c, out_h, out_w); + Winograd_kernel<<>>(src1, src0, dst, tiles_dim_w, tiles_dim_h, in_c, in_h, in_w, tile_size, X, Y, filt_k, filt_c, out_c, tile_2d_s, out_h, out_w); } @@ -816,8 +822,8 @@ void ggml_cuda_op_winograd_stage1(ggml_backend_cuda_context & ctx, ggml_tensor * cudaMemcpyToSymbol(access_f_s, aux, 64*sizeof(int)); cudaMemcpyToSymbol(access_s, aux2, 64*sizeof(int)); cudaMemcpyToSymbol(tileid, tid, 64*sizeof(int)); - - 
conv_winograd_stage1_f16_f32_cuda(tiles_dim_w, tiles_dim_h, 4, 8, + printf(" %d, %d, %d \n", tiles_dim_w, tiles_dim_h, tile_size); + conv_winograd_stage1_f32_f32_cuda(tiles_dim_w, tiles_dim_h, 4, 8, tile_size, tile_2d_s, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], diff --git a/src/ggml-cuda/conv-winograd.cuh b/src/ggml-cuda/conv-winograd.cuh index 2569cf36f..39bc7002c 100644 --- a/src/ggml-cuda/conv-winograd.cuh +++ b/src/ggml-cuda/conv-winograd.cuh @@ -1,6 +1,6 @@ #include "common.cuh" -// #define CUDA_CONV_TRANPOSE_1D_BLOCK_SIZE 256 + #define BC 8 #define BN 32 #define BK 64 diff --git a/src/ggml.c b/src/ggml.c index 389878991..34e0a165c 100644 --- a/src/ggml.c +++ b/src/ggml.c @@ -7179,12 +7179,12 @@ struct ggml_tensor * ggml_winograd_stage0( struct ggml_context * ctx, struct ggml_tensor * a) { bool is_node = false; - GGML_ASSERT(a->ne[0] == 3 && a->ne[1] == 3); // kernel should be 3x3 + if (a->grad) { is_node = true; } - struct ggml_tensor * result = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 4, 4, a->ne[2], a->ne[3]); + struct ggml_tensor * result = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, a->ne[0], 4, 4, a->ne[3]); result->op = GGML_OP_WINOGRAD_STAGE0; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; @@ -7208,7 +7208,7 @@ struct ggml_tensor * ggml_winograd_stage1( int OW = b->ne[0]; int OH = b->ne[1]; - struct ggml_tensor * result = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, OW, OH, a->ne[3] /* OC */, 1); + struct ggml_tensor * result = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, OW, OH, a->ne[0] /* OC */, 1); result->op = GGML_OP_WINOGRAD_STAGE1; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; @@ -7222,14 +7222,14 @@ struct ggml_tensor * ggml_conv_2d_3x3( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b){ - + GGML_ASSERT(a->ne[0] == 3 && a->ne[1] == 3); // kernel should be 3x3 GGML_ASSERT(b->ne[3] == 1); // only works for 1 input image GGML_ASSERT(b->ne[2] == a->ne[2]); // number of channels must match if(a->ne[3] % 64 != 0 || a->ne[2] % 8 != 0) // only works for the number of filters is a multiple of 64 return ggml_conv_2d(ctx, a, b, 1, 1, 1, 1, 1, 1); // and the number of channels is a multiple of 8 - - struct ggml_tensor* W = ggml_winograd_stage0(ctx, a); + struct ggml_tensor* ra = ggml_cont(ctx, ggml_permute(ctx, a, 1, 2, 3, 0)); // [N, OC, OH, OW] + struct ggml_tensor* W = ggml_winograd_stage0(ctx, ra); struct ggml_tensor * result = ggml_winograd_stage1(ctx, W, b); return result; diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index dfa649209..b5048656c 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -408,6 +408,15 @@ target_link_libraries(${TEST_TARGET} PRIVATE ggml) add_test(NAME ${TEST_TARGET} COMMAND $) set_property(TEST ${TEST_TARGET} PROPERTY ENVIRONMENT "LLVM_PROFILE_FILE=${TEST_TARGET}.profraw") +# +# test-conv2d-wino + +set(TEST_TARGET test-conv2d-winograd) +add_executable(${TEST_TARGET} ${TEST_TARGET}.cpp) +target_link_libraries(${TEST_TARGET} PRIVATE ggml) +add_test(NAME ${TEST_TARGET} COMMAND $) +set_property(TEST ${TEST_TARGET} PROPERTY ENVIRONMENT "LLVM_PROFILE_FILE=${TEST_TARGET}.profraw") + # # test-mul-mat diff --git a/tests/test-conv2d-winograd.cpp b/tests/test-conv2d-winograd.cpp index 371277399..bd5efa3bd 100644 --- a/tests/test-conv2d-winograd.cpp +++ b/tests/test-conv2d-winograd.cpp @@ -36,8 +36,8 @@ struct test_model { void load_model(test_model & model, bool use_gpu = false) { // create data - int KW = 3, KH = 3, IC = 10, OC = 10; - 
int IW = 8, IH = 6, N = 1; + int KW = 3, KH = 3, IC = 32, OC = 64; + int IW = 28, IH = 40, N = 1; // Initialize adata std::vector adata(KW * KH * IC * OC); @@ -157,16 +157,21 @@ struct ggml_cgraph * build_graph(const test_model& model) { int d0 = 1; int d1 = 1; - // split conv2d in fundamental methods for test unit - struct ggml_tensor* im2col_0 = ggml_im2col(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1, true, GGML_TYPE_F16); - ggml_set_name(im2col_0, "im2col_res"); - ggml_build_forward_expand(gf, im2col_0); + // recalculate for avoid fragmentation struct ggml_tensor* conv2d_res = ggml_conv_2d(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1); ggml_set_name(conv2d_res, "conv2d_res"); ggml_build_forward_expand(gf, conv2d_res); + int64_t *ne = conv2d_res->ne; + printf("conv2d: (%zu, %zu, %zu, %zu) \n", ne[0], ne[1], ne[2], ne[3]); + + struct ggml_tensor* wino_res = ggml_conv_2d_3x3(ctx0, model.a, model.b); + ggml_set_name(wino_res, "wino_res"); + ggml_build_forward_expand(gf, wino_res); + ne = wino_res->ne; + printf("wino: (%zu, %zu, %zu, %zu) \n", ne[0], ne[1], ne[2], ne[3]); ggml_free(ctx0); return gf; } @@ -218,173 +223,39 @@ int main(void) struct ggml_cgraph * gf_res = compute_graph(model, allocr); - struct ggml_tensor * im2col_res = NULL; + struct ggml_tensor * wino_res = NULL; struct ggml_tensor * conv2d_res = NULL; for(int i = 0; i < ggml_graph_n_nodes(gf_res); ++i) { - if(strcmp(ggml_get_name(ggml_graph_node(gf_res, i)), "im2col_res") == 0) { - im2col_res = ggml_graph_node(gf_res, i); + if(strcmp(ggml_get_name(ggml_graph_node(gf_res, i)), "wino_res") == 0) { + wino_res = ggml_graph_node(gf_res, i); } else if(strcmp(ggml_get_name(ggml_graph_node(gf_res, i)), "conv2d_res") == 0) { conv2d_res = ggml_graph_node(gf_res, i); } } - std::vector im2col_data(ggml_nelements(im2col_res)); + std::vector wino_data(ggml_nelements(wino_res)); std::vector conv2d_data(ggml_nelements(conv2d_res)); - ggml_backend_tensor_get(im2col_res, im2col_data.data(), 0, ggml_nbytes(im2col_res)); + ggml_backend_tensor_get(wino_res, wino_data.data(), 0, ggml_nbytes(wino_res)); ggml_backend_tensor_get(conv2d_res, conv2d_data.data(), 0, ggml_nbytes(conv2d_res)); - const int n_conv2d_test = 480; - const int n_im2col_test = 4320; - - float expected_conv2d [n_conv2d_test] = { - 150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f, - 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, - 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, - 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, - 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, - 150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f, - 150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f, - 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, - 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, - 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, - 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, - 150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f, - 150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f, - 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, - 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, - 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, - 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, - 150.00f, 
225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f, - 150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f, - 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, - 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, - 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, - 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, - 150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f, - 150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f, - 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, - 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, - 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, - 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, - 150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f, - 150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f, - 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, - 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, - 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, - 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, - 150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f, - 150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f, - 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, - 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, - 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, - 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, - 150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f, - 150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f, - 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, - 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, - 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, - 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, - 150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f, - 150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f, - 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, - 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, - 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, - 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, - 150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f, - 150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f, - 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, - 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, - 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, - 225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f, - 150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f }; - - uint16_t expected_im2col[n_conv2d_test] = { - 0, 0, 0, 0, 15872, 15872, 0, 15872, - 15872, 0, 0, 0, 0, 15872, 15872, 0, - 15872, 15872, 0, 0, 0, 0, 15872, 15872, - 0, 15872, 15872, 0, 0, 0, 0, 15872, - 15872, 0, 15872, 15872, 0, 0, 0, 0, - 15872, 15872, 0, 15872, 15872, 0, 0, 0, - 0, 15872, 15872, 0, 15872, 15872, 0, 0, - 0, 0, 15872, 15872, 0, 15872, 15872, 0, - 0, 0, 0, 
15872, 15872, 0, 15872, 15872, - 0, 0, 0, 0, 15872, 15872, 0, 15872, - 15872, 0, 0, 0, 0, 15872, 15872, 0, - 15872, 15872, 0, 0, 0, 15872, 15872, 15872, - 15872, 15872, 15872, 0, 0, 0, 15872, 15872, - 15872, 15872, 15872, 15872, 0, 0, 0, 15872, - 15872, 15872, 15872, 15872, 15872, 0, 0, 0, - 15872, 15872, 15872, 15872, 15872, 15872, 0, 0, - 0, 15872, 15872, 15872, 15872, 15872, 15872, 0, - 0, 0, 15872, 15872, 15872, 15872, 15872, 15872, - 0, 0, 0, 15872, 15872, 15872, 15872, 15872, - 15872, 0, 0, 0, 15872, 15872, 15872, 15872, - 15872, 15872, 0, 0, 0, 15872, 15872, 15872, - 15872, 15872, 15872, 0, 0, 0, 15872, 15872, - 15872, 15872, 15872, 15872, 0, 0, 0, 15872, - 15872, 15872, 15872, 15872, 15872, 0, 0, 0, - 15872, 15872, 15872, 15872, 15872, 15872, 0, 0, - 0, 15872, 15872, 15872, 15872, 15872, 15872, 0, - 0, 0, 15872, 15872, 15872, 15872, 15872, 15872, - 0, 0, 0, 15872, 15872, 15872, 15872, 15872, - 15872, 0, 0, 0, 15872, 15872, 15872, 15872, - 15872, 15872, 0, 0, 0, 15872, 15872, 15872, - 15872, 15872, 15872, 0, 0, 0, 15872, 15872, - 15872, 15872, 15872, 15872, 0, 0, 0, 15872, - 15872, 15872, 15872, 15872, 15872, 0, 0, 0, - 15872, 15872, 15872, 15872, 15872, 15872, 0, 0, - 0, 15872, 15872, 15872, 15872, 15872, 15872, 0, - 0, 0, 15872, 15872, 15872, 15872, 15872, 15872, - 0, 0, 0, 15872, 15872, 15872, 15872, 15872, - 15872, 0, 0, 0, 15872, 15872, 15872, 15872, - 15872, 15872, 0, 0, 0, 15872, 15872, 15872, - 15872, 15872, 15872, 0, 0, 0, 15872, 15872, - 15872, 15872, 15872, 15872, 0, 0, 0, 15872, - 15872, 15872, 15872, 15872, 15872, 0, 0, 0, - 15872, 15872, 15872, 15872, 15872, 15872, 0, 0, - 0, 15872, 15872, 15872, 15872, 15872, 15872, 0, - 0, 0, 15872, 15872, 15872, 15872, 15872, 15872, - 0, 0, 0, 15872, 15872, 15872, 15872, 15872, - 15872, 0, 0, 0, 15872, 15872, 15872, 15872, - 15872, 15872, 0, 0, 0, 15872, 15872, 15872, - 15872, 15872, 15872, 0, 0, 0, 15872, 15872, - 15872, 15872, 15872, 15872, 0, 0, 0, 15872, - 15872, 15872, 15872, 15872, 15872, 0, 0, 0, - 15872, 15872, 15872, 15872, 15872, 15872, 0, 0, - 0, 15872, 15872, 15872, 15872, 15872, 15872, 0, - 0, 0, 15872, 15872, 15872, 15872, 15872, 15872, - 0, 0, 0, 15872, 15872, 15872, 15872, 15872, - 15872, 0, 0, 0, 15872, 15872, 15872, 15872, - 15872, 15872, 0, 0, 0, 15872, 15872, 15872, - 15872, 15872, 15872, 0, 0, 0, 15872, 15872, - 15872, 15872, 15872, 15872, 0, 0, 0, 15872, - 15872, 15872, 15872, 15872, 15872, 0, 0, 0 - }; - - printf("\nPerforming test:\n"); + + printf("\nPerforming test:\n"); bool passed = true; - for(int i = 0; i < n_conv2d_test; i++) { - if( - im2col_data[i] != expected_im2col[i]) { - passed = false; - break; - } - } - - printf("ggml_im2col (%d): %s\n", (int) ggml_nelements(im2col_res), passed && (ggml_nelements(im2col_res) == n_im2col_test) ? "\033[32mPASSED\033[0m" : "\033[31mFAILED\033[0m"); - - passed = true; - for(int i = 0; i < n_conv2d_test; i++) { - if(conv2d_data[i] != expected_conv2d[i]) { - passed = false; - break; - } + // for(int i = 0; i < ggml_nelements(wino_res); i++) { + for(int i = 0; i < 3*28; i++) { + float diff = fabs(conv2d_data[i] - wino_data[i]); + // if(diff > 1.e-4) { + printf("(%f, %f, %f, %d) \n", + conv2d_data[i], + wino_data[i], diff, i); + // break; + // } } - printf("ggml_conv2d (%d): %s\n", (int) ggml_nelements(conv2d_res), passed && (ggml_nelements(conv2d_res) == n_conv2d_test) ? 
"\033[32mPASSED\033[0m" : "\033[31mFAILED\033[0m"); + ggml_free(model.ctx); From 893ca79e1e91d750fc0ddc1e8fa7ba2f774be2a5 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Fri, 27 Sep 2024 14:27:55 -0400 Subject: [PATCH 06/23] remove debugging printouts --- src/ggml-cuda/conv-winograd.cu | 8 +++---- tests/test-conv2d.cpp | 42 ++++++++++++++++++---------------- 2 files changed, 26 insertions(+), 24 deletions(-) diff --git a/src/ggml-cuda/conv-winograd.cu b/src/ggml-cuda/conv-winograd.cu index 60587db76..d47530f7e 100644 --- a/src/ggml-cuda/conv-winograd.cu +++ b/src/ggml-cuda/conv-winograd.cu @@ -752,9 +752,9 @@ static void conv_winograd_stage1_f32_f32_cuda(int tiles_dim_w, int tiles_dim_h, int64_t out_w = in_w; int smem_size = (16*BN*BC + 16*BC*BK)*4; - printf("A %d, %d\n", filt_k, filt_c); - printf("B %d, %d, %d \n", in_c, in_h, in_w); - printf("C %d, %d, %d \n", out_c, out_h, out_w); + // printf("A %d, %d\n", filt_k, filt_c); + // printf("B %d, %d, %d \n", in_c, in_h, in_w); + // printf("C %d, %d, %d \n", out_c, out_h, out_w); Winograd_kernel<<>>(src1, src0, dst, tiles_dim_w, tiles_dim_h, in_c, in_h, in_w, tile_size, X, Y, filt_k, filt_c, out_c, tile_2d_s, out_h, out_w); @@ -822,7 +822,7 @@ void ggml_cuda_op_winograd_stage1(ggml_backend_cuda_context & ctx, ggml_tensor * cudaMemcpyToSymbol(access_f_s, aux, 64*sizeof(int)); cudaMemcpyToSymbol(access_s, aux2, 64*sizeof(int)); cudaMemcpyToSymbol(tileid, tid, 64*sizeof(int)); - printf(" %d, %d, %d \n", tiles_dim_w, tiles_dim_h, tile_size); + // printf(" %d, %d, %d \n", tiles_dim_w, tiles_dim_h, tile_size); conv_winograd_stage1_f32_f32_cuda(tiles_dim_w, tiles_dim_h, 4, 8, tile_size, tile_2d_s, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], diff --git a/tests/test-conv2d.cpp b/tests/test-conv2d.cpp index 371277399..1b4078396 100644 --- a/tests/test-conv2d.cpp +++ b/tests/test-conv2d.cpp @@ -36,8 +36,8 @@ struct test_model { void load_model(test_model & model, bool use_gpu = false) { // create data - int KW = 3, KH = 3, IC = 10, OC = 10; - int IW = 8, IH = 6, N = 1; + int KW = 3, KH = 3, IC = 32, OC = 32; + int IW = 28, IH = 40, N = 1; // Initialize adata std::vector adata(KW * KH * IC * OC); @@ -365,26 +365,28 @@ int main(void) printf("\nPerforming test:\n"); - bool passed = true; - for(int i = 0; i < n_conv2d_test; i++) { - if( - im2col_data[i] != expected_im2col[i]) { - passed = false; - break; + // bool passed = true; + // for(int i = 0; i < n_conv2d_test; i++) { + // if( + // im2col_data[i] != expected_im2col[i]) { + // passed = false; + // break; + // } + // } + + // printf("ggml_im2col (%d): %s\n", (int) ggml_nelements(im2col_res), passed && (ggml_nelements(im2col_res) == n_im2col_test) ? "\033[32mPASSED\033[0m" : "\033[31mFAILED\033[0m"); + + // passed = true; + // printf("["); + for(int j = 0; j < 4; j++) { + printf("["); + for(int i = 0; i < 28; i++) { + printf("%.1f, ", conv2d_data[i]); } + printf("]\n"); } - - printf("ggml_im2col (%d): %s\n", (int) ggml_nelements(im2col_res), passed && (ggml_nelements(im2col_res) == n_im2col_test) ? "\033[32mPASSED\033[0m" : "\033[31mFAILED\033[0m"); - - passed = true; - for(int i = 0; i < n_conv2d_test; i++) { - if(conv2d_data[i] != expected_conv2d[i]) { - passed = false; - break; - } - } - - printf("ggml_conv2d (%d): %s\n", (int) ggml_nelements(conv2d_res), passed && (ggml_nelements(conv2d_res) == n_conv2d_test) ? "\033[32mPASSED\033[0m" : "\033[31mFAILED\033[0m"); + + // printf("ggml_conv2d (%d): %s\n", (int) ggml_nelements(conv2d_res), passed && (ggml_nelements(conv2d_res) == n_conv2d_test) ? 
"\033[32mPASSED\033[0m" : "\033[31mFAILED\033[0m"); ggml_free(model.ctx); From 6afbf6e1c4d6e487b2e792394c8cc490601d270a Mon Sep 17 00:00:00 2001 From: bssrdf Date: Sat, 28 Sep 2024 11:43:45 -0400 Subject: [PATCH 07/23] added a FP16 FX kernel to deal with fp16 filter data; no need to use FP32 buffer --- src/ggml-cuda/conv-winograd.cu | 172 ++++++++++++++++++++++++++------- 1 file changed, 135 insertions(+), 37 deletions(-) diff --git a/src/ggml-cuda/conv-winograd.cu b/src/ggml-cuda/conv-winograd.cu index d47530f7e..76353357a 100644 --- a/src/ggml-cuda/conv-winograd.cu +++ b/src/ggml-cuda/conv-winograd.cu @@ -356,32 +356,32 @@ float4 *input_frag_mem, float4* filter_frag_mem){ // Set of functions per row in Gw product -__device__ float f_row1(float *Gw, int j){ - return Gw[j]; +__device__ float f_row1(float *G, int j){ + return G[j]; } - __device__ float f_row2(float *Gw, int j){ - return 0.5*(Gw[j] + Gw[6+j] + Gw[3+j]); + __device__ float f_row2(float *G, int j){ + return 0.5*(G[j] + G[6+j] + G[3+j]); } - __device__ float f_row3(float *Gw, int j){ - return 0.5*(Gw[j] + Gw[6+j] - Gw[3+j]); + __device__ float f_row3(float *G, int j){ + return 0.5*(G[j] + G[6+j] - G[3+j]); } - __device__ float f_row4(float *Gw, int j){ - return Gw[6+j]; + __device__ float f_row4(float *G, int j){ + return G[6+j]; } // Set of functions per column in GwGt product - __device__ float f_col1(float *Gw, int j){ - return Gw[j]; + __device__ float f_col1(float *G, int j){ + return G[j]; } - __device__ float f_col2(float *Gw, int j){ - return 0.5*(Gw[j] + Gw[j+2] + Gw[j+1]); + __device__ float f_col2(float *G, int j){ + return 0.5*(G[j] + G[j+2] + G[j+1]); } - __device__ float f_col3(float *Gw, int j){ - return 0.5*(Gw[j] + Gw[j+2] - Gw[j+1]); + __device__ float f_col3(float *G, int j){ + return 0.5*(G[j] + G[j+2] - G[j+1]); } - __device__ float f_col4(float *Gw, int j){ - return Gw[j+2]; + __device__ float f_col4(float *G, int j){ + return G[j+2]; } - + typedef float(*pointFunction_t)(float *, int); __global__ void FX(const float *pInputs, float *pOutputs, int filt_k, @@ -403,9 +403,78 @@ __device__ float f_row1(float *Gw, int j){ pointFunction_t func2[4] = {f_col1, f_col2, f_col3, f_col4}; for(int bk=0; bk>>(src0, dst, filt_k, filt_c, src0_ne2, src0_ne1); + FX<<>>(src0, dst, filt_k, filt_c, src0_ne2, src0_ne1); + +} + +static void conv_winograd_stage0_f16_f32_cuda( + const int src0_ne0, const int src0_ne1, const int src0_ne2, const int src0_ne3, + const int dst_ne0, const int dst_ne1, const int dst_ne2, const int dst_ne3, + const half * src0, float * dst, + cudaStream_t stream) { + + + int64_t filt_k = src0_ne0; + int64_t filt_c = src0_ne3; + + FX_FP16<<>>(src0, dst, filt_k, filt_c, src0_ne2, src0_ne1); } @@ -756,38 +846,45 @@ static void conv_winograd_stage1_f32_f32_cuda(int tiles_dim_w, int tiles_dim_h, // printf("B %d, %d, %d \n", in_c, in_h, in_w); // printf("C %d, %d, %d \n", out_c, out_h, out_w); - Winograd_kernel<<>>(src1, src0, dst, + Winograd_kernel<<>>(src1, src0, dst, tiles_dim_w, tiles_dim_h, in_c, in_h, in_w, tile_size, X, Y, filt_k, filt_c, out_c, tile_2d_s, out_h, out_w); } void ggml_cuda_op_winograd_stage0(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; - // const half * src0_d = (const float *)src0->data; + // const half * src0_d = (const half *)src0->data; float * dst_d = (float *)dst->data; cudaStream_t stream = ctx.stream(); - int id = ggml_cuda_get_device(); + // int id = ggml_cuda_get_device(); // GGML_ASSERT(src0->type == GGML_TYPE_F16); GGML_ASSERT( 
dst->type == GGML_TYPE_F32); - ggml_cuda_pool_alloc src0_ddq_as_f32(ctx.pool(id)); - if (src0->type != GGML_TYPE_F32) { - const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type); - GGML_ASSERT(to_fp32_cuda != nullptr); - int64_t nle = ggml_nelements(src0); - src0_ddq_as_f32.alloc(nle); - const half * src0_dd = (const half *)src0->data; - to_fp32_cuda(src0_dd, src0_ddq_as_f32.get(), nle, stream); + // ggml_cuda_pool_alloc src0_ddq_as_f32(ctx.pool(id)); + // if (src0->type != GGML_TYPE_F32) { + // const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type); + // GGML_ASSERT(to_fp32_cuda != nullptr); + // int64_t nle = ggml_nelements(src0); + // src0_ddq_as_f32.alloc(nle); + // const char * src0_dd = (char *)src0->data; + // to_fp32_cuda(src0_dd, src0_ddq_as_f32.get(), nle, stream); + // } + + // // GGML_ASSERT(ggml_is_contiguous(src0)); + // const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *)src0->data : src0_ddq_as_f32.get(); + if(src0->type == GGML_TYPE_F32){ + const float* src0_d = (const float *)src0->data; + conv_winograd_stage0_f32_f32_cuda(src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + src0_d, dst_d, stream); + }else{ + const half * src0_d = (const half *)src0->data; + conv_winograd_stage0_f16_f32_cuda(src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + src0_d, dst_d, stream); } - - // GGML_ASSERT(ggml_is_contiguous(src0)); - const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *)src0->data : src0_ddq_as_f32.get(); - - conv_winograd_stage0_f32_f32_cuda(src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], - dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], - src0_ddf_i, dst_d, stream); } @@ -822,13 +919,14 @@ void ggml_cuda_op_winograd_stage1(ggml_backend_cuda_context & ctx, ggml_tensor * cudaMemcpyToSymbol(access_f_s, aux, 64*sizeof(int)); cudaMemcpyToSymbol(access_s, aux2, 64*sizeof(int)); cudaMemcpyToSymbol(tileid, tid, 64*sizeof(int)); - // printf(" %d, %d, %d \n", tiles_dim_w, tiles_dim_h, tile_size); + conv_winograd_stage1_f32_f32_cuda(tiles_dim_w, tiles_dim_h, 4, 8, tile_size, tile_2d_s, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], src0_d, src1_d, dst_d, stream); + } From 0491858ac21041406cac1fefc30274d6812fd2a8 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Sat, 28 Sep 2024 20:15:52 -0400 Subject: [PATCH 08/23] sync from sd.cpp --- src/ggml-cuda/conv-winograd.cu | 157 ++++++++------------------------- src/ggml.c | 6 +- 2 files changed, 41 insertions(+), 122 deletions(-) diff --git a/src/ggml-cuda/conv-winograd.cu b/src/ggml-cuda/conv-winograd.cu index 76353357a..d34be84c7 100644 --- a/src/ggml-cuda/conv-winograd.cu +++ b/src/ggml-cuda/conv-winograd.cu @@ -166,8 +166,8 @@ __device__ void __inline__ outer_product(float4* input_frag, float4* filter_frag accumulator[1][15].w += input_frag[3].w*filter_frag[3].w; } -extern "C" -{ +// extern "C" +// { __device__ __forceinline__ void transform_output_tile(float *pOutputs, float2 *C_tile, float2 *At, int round, int c_tensor, int c_glb_offset, int i1, int i2, @@ -248,7 +248,7 @@ float4 *input_frag_mem, float4* filter_frag_mem){ float2 *output_smem = (float2 *) shared_mem; float2 *accumulator = (float2 *) acumm_smem; - float2 *C_out = (float2*)C; + // float2 *C_out = (float2*)C; float2 *C_tile = (float2*) input_frag_mem; float2 *At = (float2*) filter_frag_mem; @@ -295,12 
+295,11 @@ float4 *input_frag_mem, float4* filter_frag_mem){ // blockIdx.x*BN + (threadIdx.x%16)*2+ // ((threadIdx.x/16)*16 + (threadIdx.y%4)*4 + threadIdx.y/4)*c_glb_offset; - int tx = TW, ty = TH; // int c_tile = blockIdx.x * tx + blockIdx.y * in_w * ty; // int c_tensor = c_tile + (threadIdx.x % tw) * 2 + (threadIdx.x / tw) * in_w * 2 + // threadIdx.y*(in_h*in_w) - (in_w+1); - int c_tensor = blockIdx.z*c_glb_offset*BK + blockIdx.x * tx + blockIdx.y * out_w * ty + + int c_tensor = blockIdx.z*c_glb_offset*BK + blockIdx.x * TW + blockIdx.y * out_w * TH + // (threadIdx.x % tw) * 2 + (threadIdx.x / tw) * out_w * 2 + ((threadIdx.x/16)*16 + (threadIdx.y%4)*4 + threadIdx.y/4)*c_glb_offset; @@ -382,77 +381,31 @@ __device__ float f_row1(float *G, int j){ return G[j+2]; } - typedef float(*pointFunction_t)(float *, int); - - __global__ void FX(const float *pInputs, float *pOutputs, int filt_k, - int filt_c, int filt_h, int filt_w){ + template + static __device__ __forceinline__ float t2f32(T val) { + return (float) val; + } - // assumes CHWK layout - int Inx = threadIdx.x, Iny = threadIdx.y; - int TileX = blockIdx.x, TileY = blockIdx.y; - - int c_glb_offset = filt_k*filt_h*filt_w; - int c_kernel = TileY*BC*c_glb_offset + TileX*BK + Iny*c_glb_offset + Inx; - int c_glb_offset_s = filt_k*4*4; - int c_kernel_s = TileY*BC*c_glb_offset_s + TileX*BK + Iny*c_glb_offset_s + Inx; - - float Gw[21]; //9+12. In registers - float *Gw_buffer = Gw+9; - - pointFunction_t func1[4] = {f_row1, f_row2, f_row3, f_row4}; - pointFunction_t func2[4] = {f_col1, f_col2, f_col3, f_col4}; - - for(int bk=0; bk + __device__ float __forceinline__ t2f32(half val) { + return __half2float(val); } - __global__ void FX_FP16(const half *pInputs, float *pOutputs, int filt_k, + typedef float(*pointFunction_t)(float *, int); + + template + __global__ void FX(const T *pInputs, float *pOutputs, int filt_k, int filt_c, int filt_h, int filt_w){ - // assumes CHWK layout + // assumes KCHW layout int Inx = threadIdx.x, Iny = threadIdx.y; int TileX = blockIdx.x, TileY = blockIdx.y; - int c_glb_offset = filt_k*filt_h*filt_w; - int c_kernel = TileY*BC*c_glb_offset + TileX*BK + Iny*c_glb_offset + Inx; + // int c_glb_offset = filt_k*filt_h*filt_w; + // int c_kernel = TileY*BC*c_glb_offset + TileX*BK + Iny*c_glb_offset + Inx; + int c_glb_offset = filt_h*filt_w; + // int c_kernel = TileY*BC*c_glb_offset + TileX*BK*filt_c*c_glb_offset + Iny*c_glb_offset+ Inx*filt_c*c_glb_offset; + int c_kernel = (TileY*BC + (TileX*BK+Inx)*filt_c + Iny)*c_glb_offset; int c_glb_offset_s = filt_k*4*4; int c_kernel_s = TileY*BC*c_glb_offset_s + TileX*BK + Iny*c_glb_offset_s + Inx; @@ -462,19 +415,11 @@ __device__ float f_row1(float *G, int j){ pointFunction_t func1[4] = {f_row1, f_row2, f_row3, f_row4}; pointFunction_t func2[4] = {f_col1, f_col2, f_col3, f_col4}; - for(int bk=0; bk +static void conv_winograd_stage0_f32_cuda( const int src0_ne0, const int src0_ne1, const int src0_ne2, const int src0_ne3, const int dst_ne0, const int dst_ne1, const int dst_ne2, const int dst_ne3, - const float * src0, float * dst, + const T * src0, float * dst, cudaStream_t stream) { - - int64_t filt_k = src0_ne0; - int64_t filt_c = src0_ne3; - - FX<<>>(src0, dst, filt_k, filt_c, src0_ne2, src0_ne1); - -} - -static void conv_winograd_stage0_f16_f32_cuda( - const int src0_ne0, const int src0_ne1, const int src0_ne2, const int src0_ne3, - const int dst_ne0, const int dst_ne1, const int dst_ne2, const int dst_ne3, - const half * src0, float * dst, - cudaStream_t stream) { - - - int64_t filt_k = 
src0_ne0; - int64_t filt_c = src0_ne3; - - FX_FP16<<>>(src0, dst, filt_k, filt_c, src0_ne2, src0_ne1); + FX<<>>(src0, dst, src0_ne3, src0_ne2, src0_ne1, src0_ne0); } @@ -842,12 +762,9 @@ static void conv_winograd_stage1_f32_f32_cuda(int tiles_dim_w, int tiles_dim_h, int64_t out_w = in_w; int smem_size = (16*BN*BC + 16*BC*BK)*4; - // printf("A %d, %d\n", filt_k, filt_c); - // printf("B %d, %d, %d \n", in_c, in_h, in_w); - // printf("C %d, %d, %d \n", out_c, out_h, out_w); - Winograd_kernel<<>>(src1, src0, dst, - tiles_dim_w, tiles_dim_h, in_c, in_h, in_w, tile_size, X, Y, filt_k, filt_c, out_c, tile_2d_s, out_h, out_w); + tiles_dim_w, tiles_dim_h, in_c, in_h, in_w, tile_size, X, Y, + filt_k, filt_c, out_c, tile_2d_s, out_h, out_w); } @@ -876,12 +793,14 @@ void ggml_cuda_op_winograd_stage0(ggml_backend_cuda_context & ctx, ggml_tensor * // const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *)src0->data : src0_ddq_as_f32.get(); if(src0->type == GGML_TYPE_F32){ const float* src0_d = (const float *)src0->data; - conv_winograd_stage0_f32_f32_cuda(src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], + // conv_winograd_stage0_f32_f32_cuda(src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], + conv_winograd_stage0_f32_cuda(src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], src0_d, dst_d, stream); }else{ const half * src0_d = (const half *)src0->data; - conv_winograd_stage0_f16_f32_cuda(src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], + // conv_winograd_stage0_f16_f32_cuda(src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], + conv_winograd_stage0_f32_cuda(src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], src0_d, dst_d, stream); } diff --git a/src/ggml.c b/src/ggml.c index 34e0a165c..7b1868bc9 100644 --- a/src/ggml.c +++ b/src/ggml.c @@ -7184,7 +7184,7 @@ struct ggml_tensor * ggml_winograd_stage0( is_node = true; } - struct ggml_tensor * result = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, a->ne[0], 4, 4, a->ne[3]); + struct ggml_tensor * result = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, a->ne[3], 4, 4, a->ne[2]); result->op = GGML_OP_WINOGRAD_STAGE0; result->grad = is_node ? 
ggml_dup_tensor(ctx, result) : NULL; @@ -7228,8 +7228,8 @@ struct ggml_tensor * ggml_conv_2d_3x3( if(a->ne[3] % 64 != 0 || a->ne[2] % 8 != 0) // only works for the number of filters is a multiple of 64 return ggml_conv_2d(ctx, a, b, 1, 1, 1, 1, 1, 1); // and the number of channels is a multiple of 8 - struct ggml_tensor* ra = ggml_cont(ctx, ggml_permute(ctx, a, 1, 2, 3, 0)); // [N, OC, OH, OW] - struct ggml_tensor* W = ggml_winograd_stage0(ctx, ra); + // struct ggml_tensor* ra = ggml_cont(ctx, ggml_permute(ctx, a, 1, 2, 3, 0)); // [N, OC, OH, OW] + struct ggml_tensor* W = ggml_winograd_stage0(ctx, a); struct ggml_tensor * result = ggml_winograd_stage1(ctx, W, b); return result; From 93c3da7ab7dfff9bfbf3fb91e5500ae58243d0f4 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Sat, 28 Sep 2024 20:37:38 -0400 Subject: [PATCH 09/23] restore test-conv2d.cpp test --- tests/test-conv2d.cpp | 42 ++++++++++++++++++++---------------------- 1 file changed, 20 insertions(+), 22 deletions(-) diff --git a/tests/test-conv2d.cpp b/tests/test-conv2d.cpp index 1b4078396..371277399 100644 --- a/tests/test-conv2d.cpp +++ b/tests/test-conv2d.cpp @@ -36,8 +36,8 @@ struct test_model { void load_model(test_model & model, bool use_gpu = false) { // create data - int KW = 3, KH = 3, IC = 32, OC = 32; - int IW = 28, IH = 40, N = 1; + int KW = 3, KH = 3, IC = 10, OC = 10; + int IW = 8, IH = 6, N = 1; // Initialize adata std::vector adata(KW * KH * IC * OC); @@ -365,28 +365,26 @@ int main(void) printf("\nPerforming test:\n"); - // bool passed = true; - // for(int i = 0; i < n_conv2d_test; i++) { - // if( - // im2col_data[i] != expected_im2col[i]) { - // passed = false; - // break; - // } - // } - - // printf("ggml_im2col (%d): %s\n", (int) ggml_nelements(im2col_res), passed && (ggml_nelements(im2col_res) == n_im2col_test) ? "\033[32mPASSED\033[0m" : "\033[31mFAILED\033[0m"); - - // passed = true; - // printf("["); - for(int j = 0; j < 4; j++) { - printf("["); - for(int i = 0; i < 28; i++) { - printf("%.1f, ", conv2d_data[i]); + bool passed = true; + for(int i = 0; i < n_conv2d_test; i++) { + if( + im2col_data[i] != expected_im2col[i]) { + passed = false; + break; } - printf("]\n"); } - - // printf("ggml_conv2d (%d): %s\n", (int) ggml_nelements(conv2d_res), passed && (ggml_nelements(conv2d_res) == n_conv2d_test) ? "\033[32mPASSED\033[0m" : "\033[31mFAILED\033[0m"); + + printf("ggml_im2col (%d): %s\n", (int) ggml_nelements(im2col_res), passed && (ggml_nelements(im2col_res) == n_im2col_test) ? "\033[32mPASSED\033[0m" : "\033[31mFAILED\033[0m"); + + passed = true; + for(int i = 0; i < n_conv2d_test; i++) { + if(conv2d_data[i] != expected_conv2d[i]) { + passed = false; + break; + } + } + + printf("ggml_conv2d (%d): %s\n", (int) ggml_nelements(conv2d_res), passed && (ggml_nelements(conv2d_res) == n_conv2d_test) ? 
"\033[32mPASSED\033[0m" : "\033[31mFAILED\033[0m"); ggml_free(model.ctx); From 4e8e0d471644d57fa45cf3530c5067c511348133 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Sat, 28 Sep 2024 20:39:33 -0400 Subject: [PATCH 10/23] restore src/CMakeLists.txt --- src/CMakeLists.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 83c6cb2ce..cbc349500 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -285,9 +285,9 @@ if (GGML_CUDA) # 61 == integer CUDA intrinsics # 70 == compute capability at which unrolling a loop in mul_mat_q kernels is faster if (GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16) - set(CMAKE_CUDA_ARCHITECTURES "60;61;70;75;86") + set(CMAKE_CUDA_ARCHITECTURES "60;61;70;75") else() - set(CMAKE_CUDA_ARCHITECTURES "52;61;70;75;86") + set(CMAKE_CUDA_ARCHITECTURES "52;61;70;75") #set(CMAKE_CUDA_ARCHITECTURES "OFF") # use this to compile much faster, but only F16 models work endif() endif() From e0e94c485aa143660e9861b0703cfcacc77a4c64 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Sun, 29 Sep 2024 17:09:35 -0400 Subject: [PATCH 11/23] change mask to unsigned int; add __restric__ to various pointers --- src/ggml-cuda/conv-winograd.cu | 59 +++++++++++++++++----------------- 1 file changed, 29 insertions(+), 30 deletions(-) diff --git a/src/ggml-cuda/conv-winograd.cu b/src/ggml-cuda/conv-winograd.cu index d34be84c7..578a500f0 100644 --- a/src/ggml-cuda/conv-winograd.cu +++ b/src/ggml-cuda/conv-winograd.cu @@ -169,9 +169,9 @@ __device__ void __inline__ outer_product(float4* input_frag, float4* filter_frag // extern "C" // { -__device__ __forceinline__ void transform_output_tile(float *pOutputs, float2 *C_tile, float2 *At, +__device__ __forceinline__ void transform_output_tile(float * __restrict__ pOutputs, float2 *C_tile, float2 *At, int round, int c_tensor, int c_glb_offset, int i1, int i2, - unsigned short mask1, unsigned short mask2, int out_w) + unsigned int mask1, unsigned int mask2, int out_w) { c_tensor += (((round)/2)*32 + ((round)%2)*2)*c_glb_offset; @@ -208,10 +208,10 @@ __device__ __forceinline__ void transform_output_tile(float *pOutputs, float2 * } } -__device__ __forceinline__ unsigned short get_mask(int idd, int tiles_dim_w, int tiles_dim_h, +__device__ __forceinline__ unsigned int get_mask(int idd, int tiles_dim_w, int tiles_dim_h, int tw, int th, int out_w, int out_h){ - unsigned short mask = 0x000F; + unsigned int mask = 0x000F; // if((blockIdx.y/tiles_dim)==(tiles_dim-1) && out_w%2) mask&=0x0003; // pad bottom row // if(!((blockIdx.y+1)%tiles_dim) && out_w%2) mask&=0X0005; // pad right col // if(blockIdx.y==gridDim.y-1 && (idd / tw) == th-1 && out_h%2) mask&=0x0003; // pad bottom row @@ -242,7 +242,7 @@ __device__ __forceinline__ unsigned short get_mask(int idd, int tiles_dim_w, int return mask; } -__device__ __forceinline__ void store_output_tile(float4 acumm_smem[][16], float *shared_mem, float *C, +__device__ __forceinline__ void store_output_tile(float4 acumm_smem[][16], float *shared_mem, float * __restrict__ C, int out_h, int out_w, int tiles_dim_w, int tiles_dim_h, int tw, int th, float4 *input_frag_mem, float4* filter_frag_mem){ @@ -271,8 +271,8 @@ float4 *input_frag_mem, float4* filter_frag_mem){ int id2 = (idd2 % tw) * 2 + (idd2 / tw) * out_w * 2; // unsigned short mask1 = 0x000F; - unsigned short mask1 = get_mask(idd1, tiles_dim_w, tiles_dim_h, tw, th, out_w, out_h); - unsigned short mask2 = get_mask(idd2, tiles_dim_w, tiles_dim_h, tw, th, out_w, out_h); + unsigned int mask1 = get_mask(idd1, tiles_dim_w, 
tiles_dim_h, tw, th, out_w, out_h); + unsigned int mask2 = get_mask(idd2, tiles_dim_w, tiles_dim_h, tw, th, out_w, out_h); // output transpose step int t=0; @@ -355,29 +355,29 @@ float4 *input_frag_mem, float4* filter_frag_mem){ // Set of functions per row in Gw product -__device__ float f_row1(float *G, int j){ +__device__ float f_row1(float * __restrict__ G, int j){ return G[j]; } - __device__ float f_row2(float *G, int j){ - return 0.5*(G[j] + G[6+j] + G[3+j]); + __device__ float f_row2(float * __restrict__ G, int j){ + return 0.5f*(G[j] + G[6+j] + G[3+j]); } - __device__ float f_row3(float *G, int j){ - return 0.5*(G[j] + G[6+j] - G[3+j]); + __device__ float f_row3(float * __restrict__ G, int j){ + return 0.5f*(G[j] + G[6+j] - G[3+j]); } - __device__ float f_row4(float *G, int j){ + __device__ float f_row4(float * __restrict__ G, int j){ return G[6+j]; } // Set of functions per column in GwGt product - __device__ float f_col1(float *G, int j){ + __device__ float f_col1(float * __restrict__ G, int j){ return G[j]; } - __device__ float f_col2(float *G, int j){ - return 0.5*(G[j] + G[j+2] + G[j+1]); + __device__ float f_col2(float * __restrict__ G, int j){ + return 0.5f*(G[j] + G[j+2] + G[j+1]); } - __device__ float f_col3(float *G, int j){ - return 0.5*(G[j] + G[j+2] - G[j+1]); + __device__ float f_col3(float * __restrict__ G, int j){ + return 0.5f*(G[j] + G[j+2] - G[j+1]); } - __device__ float f_col4(float *G, int j){ + __device__ float f_col4(float * __restrict__ G, int j){ return G[j+2]; } @@ -394,10 +394,10 @@ __device__ float f_row1(float *G, int j){ typedef float(*pointFunction_t)(float *, int); template - __global__ void FX(const T *pInputs, float *pOutputs, int filt_k, + __global__ void FX(const T * __restrict__ pInputs, float * __restrict__ pOutputs, int filt_k, int filt_c, int filt_h, int filt_w){ - // assumes KCHW layout + // assumes KCHW layout int Inx = threadIdx.x, Iny = threadIdx.y; int TileX = blockIdx.x, TileY = blockIdx.y; @@ -418,7 +418,6 @@ __device__ float f_row1(float *G, int j){ for(int bk=0; bk -static void conv_winograd_stage0_f32_cuda( +static void conv_winograd_stage0_f32_cuda( const int src0_ne0, const int src0_ne1, const int src0_ne2, const int src0_ne3, const int dst_ne0, const int dst_ne1, const int dst_ne2, const int dst_ne3, const T * src0, float * dst, @@ -764,7 +763,7 @@ static void conv_winograd_stage1_f32_f32_cuda(int tiles_dim_w, int tiles_dim_h, Winograd_kernel<<>>(src1, src0, dst, tiles_dim_w, tiles_dim_h, in_c, in_h, in_w, tile_size, X, Y, - filt_k, filt_c, out_c, tile_2d_s, out_h, out_w); + filt_k, filt_c, out_c, tile_2d_s, out_h, out_w); } From c8700ca9e8a3db3b10b4afe2c0e5bbe6fbb1c89f Mon Sep 17 00:00:00 2001 From: bssrdf Date: Sun, 29 Sep 2024 17:10:11 -0400 Subject: [PATCH 12/23] fix indentation --- src/ggml-cuda.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ggml-cuda.cu b/src/ggml-cuda.cu index c70f5b70c..4ec75d04e 100644 --- a/src/ggml-cuda.cu +++ b/src/ggml-cuda.cu @@ -2332,7 +2332,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_CONV_TRANSPOSE_1D: ggml_cuda_op_conv_transpose_1d(ctx,dst); break; - case GGML_OP_WINOGRAD_STAGE0: + case GGML_OP_WINOGRAD_STAGE0: ggml_cuda_op_winograd_stage0(ctx, dst); break; case GGML_OP_WINOGRAD_STAGE1: From c5d43a2dfa25c4731967da67c3a3187631623ded Mon Sep 17 00:00:00 2001 From: bssrdf Date: Mon, 30 Sep 2024 20:41:42 -0400 Subject: [PATCH 13/23] skip if already computed in a preprocess step --- src/ggml-cuda/conv-winograd.cu | 7 
++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/ggml-cuda/conv-winograd.cu b/src/ggml-cuda/conv-winograd.cu index 578a500f0..fb40822a8 100644 --- a/src/ggml-cuda/conv-winograd.cu +++ b/src/ggml-cuda/conv-winograd.cu @@ -770,7 +770,12 @@ static void conv_winograd_stage1_f32_f32_cuda(int tiles_dim_w, int tiles_dim_h, void ggml_cuda_op_winograd_stage0(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; // const half * src0_d = (const half *)src0->data; - + // in case this tensor has already been computed in a preprocessing step, + // skip this time; + if(src0 == NULL){ + return; + } + float * dst_d = (float *)dst->data; cudaStream_t stream = ctx.stream(); // int id = ggml_cuda_get_device(); From 00ad37ee68f8bbbbcbecb820683d6e849f62bbfd Mon Sep 17 00:00:00 2001 From: bssrdf Date: Mon, 30 Sep 2024 22:02:53 -0400 Subject: [PATCH 14/23] added winograd conv2d to backend op tests --- tests/test-backend-ops.cpp | 49 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 9a96cfc4c..678940090 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -2246,6 +2246,51 @@ struct test_im2col : public test_case { } }; +// GGML_Conv2D +struct test_conv2d : public test_case { + const ggml_type type_input; + const ggml_type type_kernel; + const ggml_type dst_type; + const std::array ne_input; + const std::array ne_kernel; + // stride + const int s0; + const int s1; + // padding + const int p0; + const int p1; + // dilation + const int d0; + const int d1; + // mode + + std::string vars() override { + return VARS_TO_STR11(type_input, type_kernel, dst_type, ne_input, ne_kernel, s0, s1, p0, p1, d0, d1); + } + + test_conv2d(ggml_type type_input = GGML_TYPE_F32, ggml_type type_kernel = GGML_TYPE_F16, ggml_type dst_type = GGML_TYPE_F32, + std::array ne_input = {10, 10, 3, 1}, // [input_width, input_height, input_channels, 1] + std::array ne_kernel = {3, 3, 3, 1}, // [kernel_width, kernel_height, input_channels, 1] + int s0 = 1, int s1 = 1, + int p0 = 1, int p1 = 1, + int d0 = 1, int d1 = 1) + : type_input(type_input), type_kernel(type_kernel), dst_type(dst_type), ne_input(ne_input), ne_kernel(ne_kernel), s0(s0), s1(s1), p0(p0), p1(p1), d0(d0), d1(d1) + {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * input = ggml_new_tensor(ctx, type_input, 4, ne_input.data()); + ggml_set_name(input, "input"); + + ggml_tensor * kernel = ggml_new_tensor(ctx, type_kernel, 4, ne_kernel.data()); + ggml_set_name(kernel, "kernel"); + + ggml_tensor * out = ggml_conv_2d_3x3(ctx, kernel, input); + ggml_set_name(out, "out"); + + return out; + } +}; + // GGML_OP_CONCAT struct test_concat : public test_case { const ggml_type type; @@ -3252,6 +3297,10 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false)); test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false)); + test_cases.emplace_back(new test_conv2d(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32, {56, 80, 640, 1}, {3, 3, 640, 960}, 1, 1, 1, 1, 1, 1)); + test_cases.emplace_back(new test_conv2d(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32, {56, 80, 1280, 1}, {3, 3, 1280, 1280}, 1, 1, 1, 1, 1, 1)); + test_cases.emplace_back(new 
test_conv2d(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32, {56, 80, 1280, 1}, {3, 3, 1280, 2560}, 1, 1, 1, 1, 1, 1)); + // sycl backend will limit task global_range < MAX_INT // test cases for 2D im2col with large input W and H (occurs in stable-diffusion) // however these cases need to alloc more memory which may fail in some devices (Intel Arc770, etc.) From 4f93d671bfea6b6461cc972515ef6ed1c4c02cac Mon Sep 17 00:00:00 2001 From: bssrdf Date: Wed, 2 Oct 2024 10:58:44 -0400 Subject: [PATCH 15/23] add conv2d as a test case in test-backend-op --- tests/test-backend-ops.cpp | 73 ++++++++++++++++++++++++++++++++++---- 1 file changed, 67 insertions(+), 6 deletions(-) diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 678940090..12cac76a1 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -644,7 +644,7 @@ struct test_case { fflush(stdout); // check if backends support op - if (!ggml_backend_supports_op(backend, out)) { + if (!ggml_backend_supports_op(backend, out) && strcmp(op_name, "CONV2D")) { printf("not supported\n"); ggml_free(ctx); return true; @@ -2246,7 +2246,7 @@ struct test_im2col : public test_case { } }; -// GGML_Conv2D + struct test_conv2d : public test_case { const ggml_type type_input; const ggml_type type_kernel; @@ -2263,12 +2263,63 @@ struct test_conv2d : public test_case { const int d0; const int d1; // mode + const bool is_2D; std::string vars() override { return VARS_TO_STR11(type_input, type_kernel, dst_type, ne_input, ne_kernel, s0, s1, p0, p1, d0, d1); } + std::string op_desc(ggml_tensor * t) override { + return std::string("CONV2D"); + } + + test_conv2d(ggml_type type_input = GGML_TYPE_F32, ggml_type type_kernel = GGML_TYPE_F16, ggml_type dst_type = GGML_TYPE_F32, + std::array ne_input = {10, 10, 3, 1}, // [input_width, input_height, input_channels, 1] + std::array ne_kernel = {3, 3, 3, 1}, // [kernel_width, kernel_height, input_channels, 1] + int s0 = 1, int s1 = 1, + int p0 = 1, int p1 = 1, + int d0 = 1, int d1 = 1, + bool is_2D = true ) + : type_input(type_input), type_kernel(type_kernel), dst_type(dst_type), ne_input(ne_input), ne_kernel(ne_kernel), s0(s0), s1(s1), p0(p0), p1(p1), d0(d0), d1(d1) + , is_2D(is_2D){} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * input = ggml_new_tensor(ctx, type_input, 4, ne_input.data()); + ggml_set_name(input, "input"); + + ggml_tensor * kernel = ggml_new_tensor(ctx, type_kernel, 4, ne_kernel.data()); + ggml_set_name(kernel, "kernel"); + + ggml_tensor * out = ggml_conv_2d(ctx, kernel, input, 1, 1, 1, 1, 1, 1); + ggml_set_name(out, "out"); + + return out; + } +}; +// GGML_Conv2D +struct test_conv2d_wino : public test_case { + const ggml_type type_input; + const ggml_type type_kernel; + const ggml_type dst_type; + const std::array ne_input; + const std::array ne_kernel; + // stride + const int s0; + const int s1; + // padding + const int p0; + const int p1; + // dilation + const int d0; + const int d1; + // mode + + std::string vars() override { + return VARS_TO_STR11(type_input, type_kernel, dst_type, ne_input, ne_kernel, s0, s1, p0, p1, d0, d1); + } + + test_conv2d_wino(ggml_type type_input = GGML_TYPE_F32, ggml_type type_kernel = GGML_TYPE_F16, ggml_type dst_type = GGML_TYPE_F32, std::array ne_input = {10, 10, 3, 1}, // [input_width, input_height, input_channels, 1] std::array ne_kernel = {3, 3, 3, 1}, // [kernel_width, kernel_height, input_channels, 1] int s0 = 1, int s1 = 1, @@ -3296,10 +3347,20 @@ static bool test_backend(ggml_backend_t backend, 
test_mode mode, const char * op test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false)); test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false)); test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false)); - - test_cases.emplace_back(new test_conv2d(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32, {56, 80, 640, 1}, {3, 3, 640, 960}, 1, 1, 1, 1, 1, 1)); - test_cases.emplace_back(new test_conv2d(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32, {56, 80, 1280, 1}, {3, 3, 1280, 1280}, 1, 1, 1, 1, 1, 1)); - test_cases.emplace_back(new test_conv2d(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32, {56, 80, 1280, 1}, {3, 3, 1280, 2560}, 1, 1, 1, 1, 1, 1)); + test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32, {56, 80, 1280, 1}, {3, 3, 1280, 2560}, 1, 1, 1, 1, 1, 1,true)); + + test_cases.emplace_back(new test_conv2d(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32, {56, 80, 640, 1}, {3, 3, 640, 960}, 1, 1, 1, 1, 1, 1, true)); + test_cases.emplace_back(new test_conv2d(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32, {56, 80, 1280, 1}, {3, 3, 1280, 1280}, 1, 1, 1, 1, 1, 1,true)); + test_cases.emplace_back(new test_conv2d(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32, {56, 80, 1280, 1}, {3, 3, 1280, 2560}, 1, 1, 1, 1, 1, 1,true)); + test_cases.emplace_back(new test_conv2d(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32, {512, 512, 512, 1}, {3, 3, 512, 256}, 1, 1, 1, 1, 1, 1,true)); + // test_cases.emplace_back(new test_conv2d(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {56, 80, 1280, 1}, {3, 3, 1280, 2560}, 1, 1, 1, 1, 1, 1,true)); + + test_cases.emplace_back(new test_conv2d_wino(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32, {56, 80, 640, 1}, {3, 3, 640, 960}, 1, 1, 1, 1, 1, 1)); + test_cases.emplace_back(new test_conv2d_wino(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32, {56, 80, 1280, 1}, {3, 3, 1280, 1280}, 1, 1, 1, 1, 1, 1)); + test_cases.emplace_back(new test_conv2d_wino(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32, {56, 80, 320, 1}, {3, 3, 320, 640}, 1, 1, 1, 1, 1, 1)); + test_cases.emplace_back(new test_conv2d_wino(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32, {56, 80, 1280, 1}, {3, 3, 1280, 2560}, 1, 1, 1, 1, 1, 1)); + test_cases.emplace_back(new test_conv2d_wino(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {56, 80, 1280, 1}, {3, 3, 1280, 2560}, 1, 1, 1, 1, 1, 1)); + test_cases.emplace_back(new test_conv2d_wino(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32, {512, 512, 512, 1}, {3, 3, 512, 256}, 1, 1, 1, 1, 1, 1)); // sycl backend will limit task global_range < MAX_INT // test cases for 2D im2col with large input W and H (occurs in stable-diffusion) From 1d606ed2c5c88a477da38131436648b321a64cd3 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Wed, 18 Dec 2024 08:55:45 -0500 Subject: [PATCH 16/23] further refine super tile size; implement a proper performance comparison --- src/ggml-cuda/conv-winograd.cu | 2 +- src/ggml-cuda/conv-winograd.cuh | 4 +- tests/test-conv2d-winograd.cpp | 187 +++++++++++++++++++++++++++----- 3 files changed, 165 insertions(+), 28 deletions(-) diff --git a/src/ggml-cuda/conv-winograd.cu b/src/ggml-cuda/conv-winograd.cu index fb40822a8..ca89f1657 100644 --- a/src/ggml-cuda/conv-winograd.cu +++ b/src/ggml-cuda/conv-winograd.cu @@ -843,7 +843,7 @@ void ggml_cuda_op_winograd_stage1(ggml_backend_cuda_context & 
ctx, ggml_tensor * cudaMemcpyToSymbol(access_s, aux2, 64*sizeof(int)); cudaMemcpyToSymbol(tileid, tid, 64*sizeof(int)); - conv_winograd_stage1_f32_f32_cuda(tiles_dim_w, tiles_dim_h, 4, 8, + conv_winograd_stage1_f32_f32_cuda(tiles_dim_w, tiles_dim_h, 16, 2, tile_size, tile_2d_s, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], diff --git a/src/ggml-cuda/conv-winograd.cuh b/src/ggml-cuda/conv-winograd.cuh index 39bc7002c..ebc3f48fa 100644 --- a/src/ggml-cuda/conv-winograd.cuh +++ b/src/ggml-cuda/conv-winograd.cuh @@ -4,8 +4,8 @@ #define BC 8 #define BN 32 #define BK 64 -#define TW 8 -#define TH 16 +#define TW 32 +#define TH 4 #define BN_p 138 __constant__ int access_f_s[2][32]; diff --git a/tests/test-conv2d-winograd.cpp b/tests/test-conv2d-winograd.cpp index bd5efa3bd..0c08c6817 100644 --- a/tests/test-conv2d-winograd.cpp +++ b/tests/test-conv2d-winograd.cpp @@ -4,6 +4,7 @@ #ifdef GGML_USE_CUDA #include "ggml-cuda.h" +//#include #endif #ifdef GGML_USE_METAL @@ -36,8 +37,10 @@ struct test_model { void load_model(test_model & model, bool use_gpu = false) { // create data - int KW = 3, KH = 3, IC = 32, OC = 64; - int IW = 28, IH = 40, N = 1; + int KW = 3, KH = 3, IC = 256, OC = 256; + int IW = 832, IH = 1216, N = 1; + + printf(" input: IC = %d, OC = %d, IW = %d, IH = %d \n ", IC, OC, IW, IH); // Initialize adata std::vector adata(KW * KH * IC * OC); @@ -135,7 +138,7 @@ void load_model(test_model & model, bool use_gpu = false) { } } -struct ggml_cgraph * build_graph(const test_model& model) { +struct ggml_cgraph * build_graph_0(const test_model& model) { static size_t buf_size = ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead(); static std::vector buf(buf_size); @@ -163,21 +166,104 @@ struct ggml_cgraph * build_graph(const test_model& model) { struct ggml_tensor* conv2d_res = ggml_conv_2d(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1); ggml_set_name(conv2d_res, "conv2d_res"); ggml_build_forward_expand(gf, conv2d_res); - int64_t *ne = conv2d_res->ne; - printf("conv2d: (%zu, %zu, %zu, %zu) \n", ne[0], ne[1], ne[2], ne[3]); + // int64_t *ne = conv2d_res->ne; + // printf("conv2d: (%zu, %zu, %zu, %zu) \n", ne[0], ne[1], ne[2], ne[3]); + + + // struct ggml_tensor* wino_res = ggml_conv_2d_3x3(ctx0, model.a, model.b); + // ggml_set_name(wino_res, "wino_res"); + // ggml_build_forward_expand(gf, wino_res); + // ne = wino_res->ne; + // printf("wino: (%zu, %zu, %zu, %zu) \n", ne[0], ne[1], ne[2], ne[3]); + ggml_free(ctx0); + return gf; +} + +struct ggml_cgraph * build_graph_1(const test_model& model) { + static size_t buf_size = ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead(); + static std::vector buf(buf_size); + + struct ggml_init_params params0 = { + /*.mem_size =*/ buf_size, + /*.mem_buffer =*/ buf.data(), + /*.no_alloc =*/ true, // the tensors will be allocated later by ggml_gallocr_alloc_graph() + }; + + // create a temporally context to build the graph + struct ggml_context * ctx0 = ggml_init(params0); + + struct ggml_cgraph * gf = ggml_new_graph(ctx0); + + int s0 = 1; + int s1 = 1; + int p0 = 1; + int p1 = 1; + int d0 = 1; + int d1 = 1; + + + + // recalculate for avoid fragmentation + // struct ggml_tensor* conv2d_res = ggml_conv_2d(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1); + // ggml_set_name(conv2d_res, "conv2d_res"); + // ggml_build_forward_expand(gf, conv2d_res); + // int64_t *ne = conv2d_res->ne; + // printf("conv2d: (%zu, %zu, %zu, %zu) \n", ne[0], ne[1], ne[2], ne[3]); struct 
ggml_tensor* wino_res = ggml_conv_2d_3x3(ctx0, model.a, model.b); ggml_set_name(wino_res, "wino_res"); ggml_build_forward_expand(gf, wino_res); - ne = wino_res->ne; - printf("wino: (%zu, %zu, %zu, %zu) \n", ne[0], ne[1], ne[2], ne[3]); + // ne = wino_res->ne; + // printf("wino: (%zu, %zu, %zu, %zu) \n", ne[0], ne[1], ne[2], ne[3]); ggml_free(ctx0); return gf; } -struct ggml_cgraph * compute_graph(const test_model & model, ggml_gallocr_t allocr) { - struct ggml_cgraph * gf = build_graph(model); +struct ggml_cgraph * compute_graph_0(const test_model & model, ggml_gallocr_t allocr) { + struct ggml_cgraph * gf = build_graph_0(model); + + // allocate tensors + ggml_gallocr_alloc_graph(allocr, gf); + int n_threads = 1; + + if (ggml_backend_is_cpu(model.backend)) { + ggml_backend_cpu_set_n_threads(model.backend, n_threads); + } + +#ifdef GGML_USE_METAL + if (ggml_backend_is_metal(model.backend)) { + ggml_backend_metal_set_n_cb(model.backend, n_threads); + } +#endif + + int iterations = 20; + + ggml_backend_graph_compute(model.backend, gf); + + ggml_backend_synchronize(model.backend); + + int64_t start_time = ggml_time_us(); + + for(int iter=0; iter 1.e-4) { - printf("(%f, %f, %f, %d) \n", - conv2d_data[i], - wino_data[i], diff, i); - // break; - // } - } + // for(int i = 0; i < 3*28; i++) { + // float diff = fabs(conv2d_data[i] - wino_data[i]); + // // if(diff > 1.e-4) { + // printf("(%f, %f, %f, %d) \n", + // conv2d_data[i], + // wino_data[i], diff, i); + // // break; + // // } + // } From 864e4a7e21d973e90673b56094e6b020fb799b63 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Wed, 25 Dec 2024 17:25:17 -0500 Subject: [PATCH 17/23] fix test case --- tests/test-conv2d-winograd.cpp | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/tests/test-conv2d-winograd.cpp b/tests/test-conv2d-winograd.cpp index 0c08c6817..2b0a72e75 100644 --- a/tests/test-conv2d-winograd.cpp +++ b/tests/test-conv2d-winograd.cpp @@ -37,8 +37,8 @@ struct test_model { void load_model(test_model & model, bool use_gpu = false) { // create data - int KW = 3, KH = 3, IC = 256, OC = 256; - int IW = 832, IH = 1216, N = 1; + int KW = 3, KH = 3, IC = 128, OC = 128; + int IW = 64, IH = 96, N = 1; printf(" input: IC = %d, OC = %d, IW = %d, IH = %d \n ", IC, OC, IW, IH); @@ -382,15 +382,15 @@ int main(void) bool passed = true; // for(int i = 0; i < ggml_nelements(wino_res); i++) { - // for(int i = 0; i < 3*28; i++) { - // float diff = fabs(conv2d_data[i] - wino_data[i]); - // // if(diff > 1.e-4) { - // printf("(%f, %f, %f, %d) \n", - // conv2d_data[i], - // wino_data[i], diff, i); - // // break; - // // } - // } + for(int i = 0; i < 3*28; i++) { + float diff = fabs(conv2d_data[i] - wino_data[i]); + // if(diff > 1.e-4) { + printf("(%f, %f, %f, %d) \n", + conv2d_data[i], + wino_data[i], diff, i); + // break; + // } + } From 2d10e24e6996d300d6c7ab2bb488fcfffde2f826 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Wed, 25 Dec 2024 20:50:20 -0500 Subject: [PATCH 18/23] changing #define macros to template arguments, step 1 --- src/ggml-cuda/conv-winograd.cu | 22 ++++++++++++++-------- src/ggml-cuda/conv-winograd.cuh | 4 ++-- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/src/ggml-cuda/conv-winograd.cu b/src/ggml-cuda/conv-winograd.cu index ca89f1657..dcf929e03 100644 --- a/src/ggml-cuda/conv-winograd.cu +++ b/src/ggml-cuda/conv-winograd.cu @@ -208,6 +208,7 @@ __device__ __forceinline__ void transform_output_tile(float * __restrict__ pOut } } +template __device__ __forceinline__ unsigned int 
get_mask(int idd, int tiles_dim_w, int tiles_dim_h, int tw, int th, int out_w, int out_h){ @@ -242,6 +243,7 @@ __device__ __forceinline__ unsigned int get_mask(int idd, int tiles_dim_w, int t return mask; } +template __device__ __forceinline__ void store_output_tile(float4 acumm_smem[][16], float *shared_mem, float * __restrict__ C, int out_h, int out_w, int tiles_dim_w, int tiles_dim_h, int tw, int th, float4 *input_frag_mem, float4* filter_frag_mem){ @@ -271,8 +273,8 @@ float4 *input_frag_mem, float4* filter_frag_mem){ int id2 = (idd2 % tw) * 2 + (idd2 / tw) * out_w * 2; // unsigned short mask1 = 0x000F; - unsigned int mask1 = get_mask(idd1, tiles_dim_w, tiles_dim_h, tw, th, out_w, out_h); - unsigned int mask2 = get_mask(idd2, tiles_dim_w, tiles_dim_h, tw, th, out_w, out_h); + unsigned int mask1 = get_mask(idd1, tiles_dim_w, tiles_dim_h, tw, th, out_w, out_h); + unsigned int mask2 = get_mask(idd2, tiles_dim_w, tiles_dim_h, tw, th, out_w, out_h); // output transpose step int t=0; @@ -520,6 +522,7 @@ __device__ __forceinline__ void prefetch_filter_tile(const float * __restrict__ } } +template __device__ __forceinline__ void prefetch_input_tile(const float * __restrict__ pInputs, float *tile, int in_h, int in_w, int tw, int th, unsigned short mask){ @@ -585,6 +588,7 @@ __device__ __forceinline__ void prefetch_input_frag(float4* input_frag, float4 *((float4*) (input_frag + 3)) = *(A_frag + frag_offset + offset2); //3=2+1 } +template __global__ void Winograd_kernel(const float *A, const float *B, float *C, int tiles_dim_w, int tiles_dim_h, int in_c, int in_h, int in_w, @@ -647,7 +651,7 @@ __global__ void Winograd_kernel(const float *A, const float *B, float *C, float4 *swap_filter; float4 *swap_input; - prefetch_input_tile(A, img_tile, in_h, in_w, X, Y, m); + prefetch_input_tile(A, img_tile, in_h, in_w, X, Y, m); prefetch_filter_tile(B, filter_tile, filt_k); float4 *input_frag_buffer = (float4*) (input_frag+4); @@ -697,7 +701,7 @@ __global__ void Winograd_kernel(const float *A, const float *B, float *C, B += filt_k*BC*4*4; if(iter<(in_c-BC)){ - prefetch_input_tile(A, img_tile, in_h, in_w, X, Y, m); + prefetch_input_tile(A, img_tile, in_h, in_w, X, Y, m); prefetch_filter_tile(B, filter_tile, filt_k); } @@ -705,12 +709,12 @@ __global__ void Winograd_kernel(const float *A, const float *B, float *C, } // Transpose, transform and store accumulated result - store_output_tile(accumulator, shared_mem, C, out_h, out_w, tiles_dim_w, tiles_dim_h, X, Y, + store_output_tile(accumulator, shared_mem, C, out_h, out_w, tiles_dim_w, tiles_dim_h, X, Y, input_frag_mem, filter_frag_mem); } -cudaError_t convolutionForward_32Tx64x8(float *k, int in_h, int in_w, float *w, int out_h, +/*cudaError_t convolutionForward_32Tx64x8(float *k, int in_h, int in_w, float *w, int out_h, int out_w, int out_c, float *C, float *Ww, int tiles_dim_w, int tiles_dim_h, int tile_size, int in_c, int filt_k, int filt_c, int filt_h, int filt_w, int m){ @@ -728,7 +732,7 @@ cudaError_t convolutionForward_32Tx64x8(float *k, int in_h, int in_w, float *w, tiles_dim_w, tiles_dim_h, in_c, in_h, in_w, tile_size, X, Y, filt_k, filt_c, out_c, tile_2d_s, out_h, out_w); return cudaGetLastError(); -} +}*/ // } @@ -760,8 +764,10 @@ static void conv_winograd_stage1_f32_f32_cuda(int tiles_dim_w, int tiles_dim_h, int64_t out_h = in_h; int64_t out_w = in_w; int smem_size = (16*BN*BC + 16*BC*BK)*4; + int max_size = 65536; // 64 KB + cudaFuncSetAttribute(Winograd_kernel<32, 4>, cudaFuncAttributeMaxDynamicSharedMemorySize, max_size); - 
Winograd_kernel<<>>(src1, src0, dst, + Winograd_kernel<32, 4><<>>(src1, src0, dst, tiles_dim_w, tiles_dim_h, in_c, in_h, in_w, tile_size, X, Y, filt_k, filt_c, out_c, tile_2d_s, out_h, out_w); } diff --git a/src/ggml-cuda/conv-winograd.cuh b/src/ggml-cuda/conv-winograd.cuh index ebc3f48fa..7b5c081c2 100644 --- a/src/ggml-cuda/conv-winograd.cuh +++ b/src/ggml-cuda/conv-winograd.cuh @@ -4,8 +4,8 @@ #define BC 8 #define BN 32 #define BK 64 -#define TW 32 -#define TH 4 +// #define TW 32 +// #define TH 4 #define BN_p 138 __constant__ int access_f_s[2][32]; From 86d7aa9811e7768c023c571f6e99ee53c0d4dced Mon Sep 17 00:00:00 2001 From: bssrdf Date: Thu, 26 Dec 2024 12:07:09 -0500 Subject: [PATCH 19/23] sync. winograd impl. from sd.cpp --- src/ggml-cuda/conv-winograd.cu | 462 +++++++++++++++++++++++--------- src/ggml-cuda/conv-winograd.cuh | 6 +- src/ggml.c | 3 +- tests/test-conv2d-winograd.cpp | 20 +- 4 files changed, 349 insertions(+), 142 deletions(-) diff --git a/src/ggml-cuda/conv-winograd.cu b/src/ggml-cuda/conv-winograd.cu index dcf929e03..0a62d8dc0 100644 --- a/src/ggml-cuda/conv-winograd.cu +++ b/src/ggml-cuda/conv-winograd.cu @@ -1,6 +1,9 @@ #include "conv-winograd.cuh" #include "convert.cuh" +#include + +#if 0 __device__ void __inline__ outer_product(float4* input_frag, float4* filter_frag, float4 accumulator[][16]){ accumulator[0][0].x += input_frag[0].x*filter_frag[0].x; accumulator[0][0].y += input_frag[0].y*filter_frag[0].x; @@ -165,6 +168,172 @@ __device__ void __inline__ outer_product(float4* input_frag, float4* filter_frag accumulator[1][15].z += input_frag[3].z*filter_frag[3].w; accumulator[1][15].w += input_frag[3].w*filter_frag[3].w; } +#endif + +template +__device__ void __inline__ outer_product(float4* input_frag, float4* filter_frag, float4 accumulator_in[][16]){ + T *accumulator = (T *)accumulator_in; + T *x = (T *)input_frag; + T *y = (T *)filter_frag; + + *(accumulator) += (*(x+0)) * (*y); + *(accumulator+1) += (*(x+1)) * (*(y)); + *(accumulator+2) += (*(x+2)) * (*(y)); + *(accumulator+3) += (*(x+3)) * (*(y)); + + *(accumulator+4) += (*(x+4)) * (*(y)); + *(accumulator+5) += (*(x+5)) * (*(y)); + *(accumulator+6) += (*(x+6)) * (*(y)); + *(accumulator+7) += (*(x+7)) * (*(y)); + + *(accumulator+8) += (*(x+0)) * (*(y+1)); + *(accumulator+9) += (*(x+1)) * (*(y+1)); + *(accumulator+10) += (*(x+2)) * (*(y+1)); + *(accumulator+11) += (*(x+3)) * (*(y+1)); + + *(accumulator+12) += (*(x+4)) * (*(y+1)); + *(accumulator+13) += (*(x+5)) * (*(y+1)); + *(accumulator+14) += (*(x+6)) * (*(y+1)); + *(accumulator+15) += (*(x+7)) * (*(y+1)); + + *(accumulator+16) += (*(x+0)) * (*(y+2)); + *(accumulator+17) += (*(x+1)) * (*(y+2)); + *(accumulator+18) += (*(x+2)) * (*(y+2)); + *(accumulator+19) += (*(x+3)) * (*(y+2)); + + *(accumulator+20) += (*(x+4)) * (*(y+2)); + *(accumulator+21) += (*(x+5)) * (*(y+2)); + *(accumulator+22) += (*(x+6)) * (*(y+2)); + *(accumulator+23) += (*(x+7)) * (*(y+2)); + + *(accumulator+24) += (*(x+0)) * (*(y+3)); + *(accumulator+25) += (*(x+1)) * (*(y+3)); + *(accumulator+26) += (*(x+2)) * (*(y+3)); + *(accumulator+27) += (*(x+3)) * (*(y+3)); + + *(accumulator+28) += (*(x+4)) * (*(y+3)); + *(accumulator+29) += (*(x+5)) * (*(y+3)); + *(accumulator+30) += (*(x+6)) * (*(y+3)); + *(accumulator+31) += (*(x+7)) * (*(y+3)); + + // + *(accumulator+32) += (*(x+0)) * (*(y+4)); + *(accumulator+33) += (*(x+1)) * (*(y+4)); + *(accumulator+34) += (*(x+2)) * (*(y+4)); + *(accumulator+35) += (*(x+3)) * (*(y+4)); + + *(accumulator+36) += (*(x+4)) * (*(y+4)); + *(accumulator+37) += 
(*(x+5)) * (*(y+4)); + *(accumulator+38) += (*(x+6)) * (*(y+4)); + *(accumulator+39) += (*(x+7)) * (*(y+4)); + + *(accumulator+40) += (*(x+0)) * (*(y+5)); + *(accumulator+41) += (*(x+1)) * (*(y+5)); + *(accumulator+42) += (*(x+2)) * (*(y+5)); + *(accumulator+43) += (*(x+3)) * (*(y+5)); + + *(accumulator+44) += (*(x+4)) * (*(y+5)); + *(accumulator+45) += (*(x+5)) * (*(y+5)); + *(accumulator+46) += (*(x+6)) * (*(y+5)); + *(accumulator+47) += (*(x+7)) * (*(y+5)); + + *(accumulator+48) += (*(x+0)) * (*(y+6)); + *(accumulator+49) += (*(x+1)) * (*(y+6)); + *(accumulator+50) += (*(x+2)) * (*(y+6)); + *(accumulator+51) += (*(x+3)) * (*(y+6)); + + *(accumulator+52) += (*(x+4)) * (*(y+6)); + *(accumulator+53) += (*(x+5)) * (*(y+6)); + *(accumulator+54) += (*(x+6)) * (*(y+6)); + *(accumulator+55) += (*(x+7)) * (*(y+6)); + + *(accumulator+56) += (*(x+0)) * (*(y+7)); + *(accumulator+57) += (*(x+1)) * (*(y+7)); + *(accumulator+58) += (*(x+2)) * (*(y+7)); + *(accumulator+59) += (*(x+3)) * (*(y+7)); + + *(accumulator+60) += (*(x+4)) * (*(y+7)); + *(accumulator+61) += (*(x+5)) * (*(y+7)); + *(accumulator+62) += (*(x+6)) * (*(y+7)); + *(accumulator+63) += (*(x+7)) * (*(y+7)); + + ////// + + *(accumulator+64) += (*(x+8)) * (*(y+8)); + *(accumulator+65) += (*(x+9)) * (*(y+8)); + *(accumulator+66) += (*(x+10)) * (*(y+8)); + *(accumulator+67) += (*(x+11)) * (*(y+8)); + *(accumulator+68) += (*(x+12)) * (*(y+8)); + *(accumulator+69) += (*(x+13)) * (*(y+8)); + *(accumulator+70) += (*(x+14)) * (*(y+8)); + *(accumulator+71) += (*(x+15)) * (*(y+8)); + + *(accumulator+72) += (*(x+8)) * (*(y+9)); + *(accumulator+73) += (*(x+9)) * (*(y+9)); + *(accumulator+74) += (*(x+10)) * (*(y+9)); + *(accumulator+75) += (*(x+11)) * (*(y+9)); + *(accumulator+76) += (*(x+12)) * (*(y+9)); + *(accumulator+77) += (*(x+13)) * (*(y+9)); + *(accumulator+78) += (*(x+14)) * (*(y+9)); + *(accumulator+79) += (*(x+15)) * (*(y+9)); + + *(accumulator+80) += (*(x+8)) * (*(y+10)); + *(accumulator+81) += (*(x+9)) * (*(y+10)); + *(accumulator+82) += (*(x+10)) * (*(y+10)); + *(accumulator+83) += (*(x+11)) * (*(y+10)); + *(accumulator+84) += (*(x+12)) * (*(y+10)); + *(accumulator+85) += (*(x+13)) * (*(y+10)); + *(accumulator+86) += (*(x+14)) * (*(y+10)); + *(accumulator+87) += (*(x+15)) * (*(y+10)); + + *(accumulator+88) += (*(x+8)) * (*(y+11)); + *(accumulator+89) += (*(x+9)) * (*(y+11)); + *(accumulator+90) += (*(x+10)) * (*(y+11)); + *(accumulator+91) += (*(x+11)) * (*(y+11)); + *(accumulator+92) += (*(x+12)) * (*(y+11)); + *(accumulator+93) += (*(x+13)) * (*(y+11)); + *(accumulator+94) += (*(x+14)) * (*(y+11)); + *(accumulator+95) += (*(x+15)) * (*(y+11)); + + // + + *(accumulator+96) += (*(x+8)) * (*(y+12)); + *(accumulator+97) += (*(x+9)) * (*(y+12)); + *(accumulator+98) += (*(x+10)) * (*(y+12)); + *(accumulator+99) += (*(x+11)) * (*(y+12)); + *(accumulator+100) += (*(x+12)) * (*(y+12)); + *(accumulator+101) += (*(x+13)) * (*(y+12)); + *(accumulator+102) += (*(x+14)) * (*(y+12)); + *(accumulator+103) += (*(x+15)) * (*(y+12)); + + *(accumulator+104) += (*(x+8)) * (*(y+13)); + *(accumulator+105) += (*(x+9)) * (*(y+13)); + *(accumulator+106) += (*(x+10)) * (*(y+13)); + *(accumulator+107) += (*(x+11)) * (*(y+13)); + *(accumulator+108) += (*(x+12)) * (*(y+13)); + *(accumulator+109) += (*(x+13)) * (*(y+13)); + *(accumulator+110) += (*(x+14)) * (*(y+13)); + *(accumulator+111) += (*(x+15)) * (*(y+13)); + + *(accumulator+112) += (*(x+8)) * (*(y+14)); + *(accumulator+113) += (*(x+9)) * (*(y+14)); + *(accumulator+114) += (*(x+10)) * (*(y+14)); + 
*(accumulator+115) += (*(x+11)) * (*(y+14)); + *(accumulator+116) += (*(x+12)) * (*(y+14)); + *(accumulator+117) += (*(x+13)) * (*(y+14)); + *(accumulator+118) += (*(x+14)) * (*(y+14)); + *(accumulator+119) += (*(x+15)) * (*(y+14)); + + *(accumulator+120) += (*(x+8)) * (*(y+15)); + *(accumulator+121) += (*(x+9)) * (*(y+15)); + *(accumulator+122) += (*(x+10)) * (*(y+15)); + *(accumulator+123) += (*(x+11)) * (*(y+15)); + *(accumulator+124) += (*(x+12)) * (*(y+15)); + *(accumulator+125) += (*(x+13)) * (*(y+15)); + *(accumulator+126) += (*(x+14)) * (*(y+15)); + *(accumulator+127) += (*(x+15)) * (*(y+15)); + + } // extern "C" // { @@ -174,7 +343,7 @@ __device__ __forceinline__ void transform_output_tile(float * __restrict__ pOut unsigned int mask1, unsigned int mask2, int out_w) { - c_tensor += (((round)/2)*32 + ((round)%2)*2)*c_glb_offset; + c_tensor += (((round)>>1)*32 + ((round)&1)*2)*c_glb_offset; int x, x1; #pragma unroll @@ -191,7 +360,7 @@ __device__ __forceinline__ void transform_output_tile(float * __restrict__ pOut #pragma unroll for(int i=0; i<2; i++){ x = i*4; - x1 = i*((out_w-(out_w%2)) + (out_w%2)/2); + x1 = i*((out_w-(out_w&1)) + (out_w&1)/2); if(mask1&(1<<(i*2))){ pOutputs[x1 + c_tensor + i1] = At[x].x + At[x+1].x + At[x+2].x; @@ -208,7 +377,6 @@ __device__ __forceinline__ void transform_output_tile(float * __restrict__ pOut } } -template __device__ __forceinline__ unsigned int get_mask(int idd, int tiles_dim_w, int tiles_dim_h, int tw, int th, int out_w, int out_h){ @@ -217,33 +385,32 @@ __device__ __forceinline__ unsigned int get_mask(int idd, int tiles_dim_w, int t // if(!((blockIdx.y+1)%tiles_dim) && out_w%2) mask&=0X0005; // pad right col // if(blockIdx.y==gridDim.y-1 && (idd / tw) == th-1 && out_h%2) mask&=0x0003; // pad bottom row // if(blockIdx.x==gridDim.x-1 && (idd % tw) == tw-1 && out_w%2) mask&=0X0005; // pad right col - if(tiles_dim_w % tw == 0 && tiles_dim_h % th == 0){ - if(blockIdx.y==gridDim.y-1 && (idd / tw) == th-1 && out_h%2) mask&=0x0003; // pad bottom row - if(blockIdx.x==gridDim.x-1 && (idd % tw) == tw-1 && out_w%2) mask&=0X0005; // pad right col - }else if(tiles_dim_w % tw == 0){ - int k = out_h % TH; - int k1 = k % 2 ? (k+1)/2 : k/2; // there could be 4*k1 tiles - if(blockIdx.y==gridDim.y-1 && (idd / tw) == k1-1 && k%2) mask&=0x0003; // pad bottom row + if((tiles_dim_w & (tw-1)) == 0 && (tiles_dim_h & (th-1)) == 0){ + if(blockIdx.y==gridDim.y-1 && (idd / tw) == th-1 && (out_h&1)) mask&=0x0003; // pad bottom row + if(blockIdx.x==gridDim.x-1 && (idd & (tw-1)) == tw-1 && (out_w&1)) mask&=0X0005; // pad right col + }else if((tiles_dim_w & (tw-1)) == 0){ + int k = out_h & (TH-1); + int k1 = k & 1 ? (k+1)>>1 : k>>1; // there could be 4*k1 tiles + if(blockIdx.y==gridDim.y-1 && (idd / tw) == k1-1 && (k&1)) mask&=0x0003; // pad bottom row if(blockIdx.y==gridDim.y-1 && (idd / tw) > k1-1) mask &= 0x0; //pad all zeros since this tile does not exist - }else if(tiles_dim_h % th == 0){ - int k = out_w % TW; - int k1 = k % 2 ? (k+1)/2 : k/2; // there could be 4*k1 tiles - if(blockIdx.x==gridDim.x-1 && (idd % tw) == k1-1 && k%2) mask&=0X0005; // pad right col - if(blockIdx.x==gridDim.x-1 && (idd % tw) > k1-1) mask&=0X0; // pad all zeroes + }else if((tiles_dim_h & (th-1)) == 0){ + int k = out_w & (TW-1); + int k1 = k & 1 ? 
(k+1) >> 1 : k >> 1; // there could be 4*k1 tiles + if(blockIdx.x==gridDim.x-1 && (idd & (tw-1)) == k1-1 && (k&1)) mask&=0X0005; // pad right col + if(blockIdx.x==gridDim.x-1 && (idd & (tw-1)) > k1-1) mask&=0X0; // pad all zeroes }else{ - int kh = out_h % TH; - int kw = out_w % TW; - int kh1 = kh % 2 ? (kh+1)/2 : kh/2; // there could be kh1*kw1 tiles - int kw1 = kw % 2 ? (kw+1)/2 : kw/2; - if(blockIdx.y==gridDim.y-1 && (idd / tw) == kh1-1 && kh%2) mask&=0x0003; // pad bottom row - if(blockIdx.x==gridDim.x-1 && (idd % tw) == kw1-1 && kw%2) mask&=0X0005; // pad right col + int kh = out_h & (TH-1); + int kw = out_w & (TW-1); + int kh1 = kh & 1 ? (kh+1) >> 1 : kh >> 1; // there could be kh1*kw1 tiles + int kw1 = kw & 1 ? (kw+1) >> 1 : kw >> 1; + if(blockIdx.y==gridDim.y-1 && (idd / tw) == kh1-1 && (kh&1)) mask&=0x0003; // pad bottom row + if(blockIdx.x==gridDim.x-1 && (idd & (tw-1)) == kw1-1 && (kw&1)) mask&=0X0005; // pad right col if(blockIdx.y==gridDim.y-1 && (idd / tw) > kh1-1) mask &= 0x0; //pad all zeros since this tile does not exist - if(blockIdx.x==gridDim.x-1 && (idd % tw) > kw1-1) mask &= 0X0; // pad all zeroes + if(blockIdx.x==gridDim.x-1 && (idd & (tw-1)) > kw1-1) mask &= 0X0; // pad all zeroes } return mask; } -template __device__ __forceinline__ void store_output_tile(float4 acumm_smem[][16], float *shared_mem, float * __restrict__ C, int out_h, int out_w, int tiles_dim_w, int tiles_dim_h, int tw, int th, float4 *input_frag_mem, float4* filter_frag_mem){ @@ -268,20 +435,20 @@ float4 *input_frag_mem, float4* filter_frag_mem){ // for 2nd tile int idd1 = tileid[0][threadIdx.x]; - int id1 = (idd1 % tw) * 2 + (idd1 / tw) * out_w * 2; + int id1 = (idd1 & (tw-1)) * 2 + (idd1 / tw) * out_w * 2; int idd2 = tileid[1][threadIdx.x]; - int id2 = (idd2 % tw) * 2 + (idd2 / tw) * out_w * 2; + int id2 = (idd2 & (tw-1)) * 2 + (idd2 / tw) * out_w * 2; // unsigned short mask1 = 0x000F; - unsigned int mask1 = get_mask(idd1, tiles_dim_w, tiles_dim_h, tw, th, out_w, out_h); - unsigned int mask2 = get_mask(idd2, tiles_dim_w, tiles_dim_h, tw, th, out_w, out_h); + unsigned int mask1 = get_mask(idd1, tiles_dim_w, tiles_dim_h, tw, th, out_w, out_h); + unsigned int mask2 = get_mask(idd2, tiles_dim_w, tiles_dim_h, tw, th, out_w, out_h); // output transpose step int t=0; int acumm1, acumm2; // For transposing //acumm1 = access_s_out[Inx]; //* 4 - acumm1 = ((threadIdx.x%8)/2)*34 + threadIdx.x%2 + (threadIdx.x/16)*2 + ((threadIdx.x/8)%2)*8; + acumm1 = ((threadIdx.x&7)>>1)*34 + (threadIdx.x&1) + (threadIdx.x>>4)*2 + ((threadIdx.x>>3)&1)*8; acumm2 = acumm1+4; int acumm4 = BN_p*16 ; //*4 @@ -290,7 +457,7 @@ float4 *input_frag_mem, float4* filter_frag_mem){ // For transformating int offset = BN_p *2; //*2/2 - int init = ( (threadIdx.y/4)*BN_p*16 + (threadIdx.y%4)*(32+2) ) *2 + threadIdx.x; + int init = ( (threadIdx.y>>2)*BN_p*16 + (threadIdx.y&3)*(32+2) ) *2 + threadIdx.x; int c_glb_offset = out_h*out_w; // int c_tensor = blockIdx.z*c_glb_offset*BK + (blockIdx.y%tiles_dim)*2 + (blockIdx.y/tiles_dim)*out_w*2 + @@ -303,7 +470,7 @@ float4 *input_frag_mem, float4* filter_frag_mem){ int c_tensor = blockIdx.z*c_glb_offset*BK + blockIdx.x * TW + blockIdx.y * out_w * TH + // (threadIdx.x % tw) * 2 + (threadIdx.x / tw) * out_w * 2 + - ((threadIdx.x/16)*16 + (threadIdx.y%4)*4 + threadIdx.y/4)*c_glb_offset; + ((threadIdx.x>>4)*16 + (threadIdx.y&3)*4 + (threadIdx.y>>2))*c_glb_offset; #pragma unroll for(int round=0; round<4; round++){ @@ -355,31 +522,48 @@ float4 *input_frag_mem, float4* filter_frag_mem){ } } +template + static 
__device__ T __forceinline__ fx_const (float val) { + return static_cast(val); + } +template <> + __device__ half __forceinline__ fx_const(float val) { + return __float2half(val); + } // Set of functions per row in Gw product -__device__ float f_row1(float * __restrict__ G, int j){ +template + __device__ T f_row1(const T * __restrict__ G, int j){ return G[j]; } - __device__ float f_row2(float * __restrict__ G, int j){ - return 0.5f*(G[j] + G[6+j] + G[3+j]); +template + __device__ T f_row2(const T * __restrict__ G, int j){ + return fx_const(0.5f)*(G[j] + G[6+j] + G[3+j]); } - __device__ float f_row3(float * __restrict__ G, int j){ - return 0.5f*(G[j] + G[6+j] - G[3+j]); +template + __device__ T f_row3(const T * __restrict__ G, int j){ + return fx_const(0.5f)*(G[j] + G[6+j] - G[3+j]); } - __device__ float f_row4(float * __restrict__ G, int j){ +template + __device__ T f_row4(const T * __restrict__ G, int j){ return G[6+j]; } + // Set of functions per column in GwGt product - __device__ float f_col1(float * __restrict__ G, int j){ +template + __device__ T f_col1(const T * __restrict__ G, int j){ return G[j]; } - __device__ float f_col2(float * __restrict__ G, int j){ - return 0.5f*(G[j] + G[j+2] + G[j+1]); +template + __device__ T f_col2(const T * __restrict__ G, int j){ + return fx_const(0.5f)*(G[j] + G[j+2] + G[j+1]); } - __device__ float f_col3(float * __restrict__ G, int j){ - return 0.5f*(G[j] + G[j+2] - G[j+1]); +template + __device__ T f_col3(const T * __restrict__ G, int j){ + return fx_const(0.5f)*(G[j] + G[j+2] - G[j+1]); } - __device__ float f_col4(float * __restrict__ G, int j){ +template + __device__ T f_col4(const T * __restrict__ G, int j){ return G[j+2]; } @@ -393,12 +577,14 @@ __device__ float f_row1(float * __restrict__ G, int j){ return __half2float(val); } - typedef float(*pointFunction_t)(float *, int); + template - __global__ void FX(const T * __restrict__ pInputs, float * __restrict__ pOutputs, int filt_k, + __global__ void FX(const T * __restrict__ pInputs, T * __restrict__ pOutputs, int filt_k, int filt_c, int filt_h, int filt_w){ + typedef T(*pointFunction_t)(const T *, int); + // assumes KCHW layout int Inx = threadIdx.x, Iny = threadIdx.y; int TileX = blockIdx.x, TileY = blockIdx.y; @@ -411,15 +597,16 @@ __device__ float f_row1(float * __restrict__ G, int j){ int c_glb_offset_s = filt_k*4*4; int c_kernel_s = TileY*BC*c_glb_offset_s + TileX*BK + Iny*c_glb_offset_s + Inx; - float Gw[21]; //9+12. In registers - float *Gw_buffer = Gw+9; + T Gw[21]; //9+12. 
In registers + T *Gw_buffer = Gw+9; pointFunction_t func1[4] = {f_row1, f_row2, f_row3, f_row4}; pointFunction_t func2[4] = {f_col1, f_col2, f_col3, f_col4}; for(int bk=0; bk +__device__ __forceinline__ void prefetch_filter_tile(const T * __restrict__ pInputs, float * __restrict__ tiles, int filt_k){ - int c_tensor = blockIdx.z*BK + (threadIdx.y*filt_k<<4) + threadIdx.x; // Iny*filt_k*4*4 + int c_tensor = blockIdx.z*BK + threadIdx.y*(filt_k<<4) + threadIdx.x; // Iny*filt_k*4*4 // each threadIdx.y corresponds to one channel; there are 8 different threadIdx.y so 8 channels //each thread (32 threads in x direction) loads 2 kernel tiles (32 in K direction apart) @@ -513,24 +701,23 @@ __device__ __forceinline__ void prefetch_filter_tile(const float * __restrict__ int acumm; #pragma unroll for(int i=0; i<4; i++){ - acumm = (i*filt_k<<2); + acumm = i*(filt_k<<2); #pragma unroll for(int j=0; j<4; j++){ - tiles[(i<<2) + j] = pInputs[acumm + j*filt_k + c_tensor]; - tiles[16 + (i<<2) + j] = pInputs[acumm + j*filt_k + c_tensor+BN]; + tiles[(i<<2) + j] = t2f32(pInputs[acumm + j*filt_k + c_tensor]); + tiles[16 + (i<<2) + j] = t2f32(pInputs[acumm + j*filt_k + c_tensor+BN]); } } } -template __device__ __forceinline__ void prefetch_input_tile(const float * __restrict__ pInputs, float *tile, int in_h, int in_w, int tw, int th, unsigned short mask){ - // load one input tile - int tx = TW, ty = TH; - int c_tile = blockIdx.x * tx + blockIdx.y * in_w * ty; - int c_tensor = c_tile + (threadIdx.x % tw) * 2 + (threadIdx.x / tw) * in_w * 2 + - threadIdx.y*(in_h*in_w) - (in_w+1); + // each thread loads two input tiles to fill a half2 buffer + int c_offset = in_h*in_w; + int c_tile = blockIdx.x * TW + blockIdx.y * in_w * TH; + int c_tensor = c_tile + ((threadIdx.x & (tw-1)) << 1) + (threadIdx.x / tw ) * (in_w << 1) + + threadIdx.y*c_offset - (in_w+1); int acumm,x; @@ -541,7 +728,7 @@ __device__ __forceinline__ void prefetch_input_tile(const float * __restrict__ p acumm = i*in_w; #pragma unroll for(int j=0; j<4; j++){ - tile[(i<<2) + j] = pInputs[acumm + j + c_tensor]; + tile[(i<<2) + j] = pInputs[acumm + j + c_tensor]; } } @@ -553,7 +740,7 @@ __device__ __forceinline__ void prefetch_input_tile(const float * __restrict__ p x = (i<<2) + j; tile[x] = 0.f; if(mask&(1< -__global__ void Winograd_kernel(const float *A, const float *B, float *C, +template +__global__ void Winograd_kernel(const float *A, const T *B, float *C, int tiles_dim_w, int tiles_dim_h, int in_c, int in_h, int in_w, int tile_size, int X, int Y, @@ -604,32 +791,32 @@ __global__ void Winograd_kernel(const float *A, const float *B, float *C, unsigned int m = 0xFFFF; if(blockIdx.y==0 && (threadIdx.x / X) == 0) m &= 0xFFF0; // pad top row - if(tiles_dim_w % X == 0 && tiles_dim_h % Y == 0){ - if(blockIdx.y==gridDim.y-1 && threadIdx.x / X == Y-1) m &= (!(in_h%2))?(0x0FFF):(0x00FF); //pad bottom row or bottom 2 rows - if(blockIdx.x==gridDim.x-1 && (threadIdx.x % X) == X-1) m &= (!(in_w%2))?(0x7777):(0x3333); // pad right col or right 2 cols - }else if(tiles_dim_w % X == 0){ - int k = in_h % TH; - int k1 = k % 2 ? 
(k+1)/2 : k/2; // there could be 4*k1 tiles - if(blockIdx.x==gridDim.x-1 && (threadIdx.x % X) == X-1) m &= (!(in_w%2))?(0x7777):(0x3333); // pad right col or right 2 cols - if(blockIdx.y==gridDim.y-1 && threadIdx.x / X == k1-1) m &= (!(k%2))?(0x0FFF):(0x00FF); //pad bottom row or bottom 2 rows + if(tiles_dim_w & (X-1) == 0 && (tiles_dim_h & (Y-1)) == 0){ + if(blockIdx.y==gridDim.y-1 && threadIdx.x / X == Y-1) m &= (!(in_h&1))?(0x0FFF):(0x00FF); //pad bottom row or bottom 2 rows + if(blockIdx.x==gridDim.x-1 && (threadIdx.x & (X-1)) == X-1) m &= (!(in_w&1))?(0x7777):(0x3333); // pad right col or right 2 cols + }else if((tiles_dim_w & (X-1)) == 0){ + int k = in_h & (TH-1); + int k1 = k & 1 ? (k+1)>>1 : (k>>1); // there could be 4*k1 tiles + if(blockIdx.x==gridDim.x-1 && (threadIdx.x & (X-1)) == X-1) m &= (!(in_w&1))?(0x7777):(0x3333); // pad right col or right 2 cols + if(blockIdx.y==gridDim.y-1 && threadIdx.x / X == k1-1) m &= (!(k&1))?(0x0FFF):(0x00FF); //pad bottom row or bottom 2 rows if(blockIdx.y==gridDim.y-1 && threadIdx.x / X > k1-1) m &= 0x0; //pad all zeros since this tile does not exist - }else if(tiles_dim_h % Y == 0){ - int k = in_w % TW; - int k1 = k % 2 ? (k+1)/2 : k/2; // there could be 8*k1 tiles - if(blockIdx.y==gridDim.y-1 && threadIdx.x / X == Y-1) m &= (!(in_h%2))?(0x0FFF):(0x00FF); //pad bottom row or bottom 2 rows - if(blockIdx.x==gridDim.x-1 && threadIdx.x % X == k1-1) m &= (!(k%2))?(0x7777):(0x3333); // pad right col or right 2 cols - if(blockIdx.x==gridDim.x-1 && threadIdx.x % X > k1-1) m &= 0x0; //pad all zeros since this tile does not exist + }else if((tiles_dim_h & (Y-1)) == 0){ + int k = in_w & (TW-1); + int k1 = k & 1 ? (k+1)>>1 : k>>1; // there could be 8*k1 tiles + if(blockIdx.y==gridDim.y-1 && threadIdx.x / X == Y-1) m &= (!(in_h&1))?(0x0FFF):(0x00FF); //pad bottom row or bottom 2 rows + if(blockIdx.x==gridDim.x-1 && (threadIdx.x & (X-1)) == k1-1) m &= (!(k&1))?(0x7777):(0x3333); // pad right col or right 2 cols + if(blockIdx.x==gridDim.x-1 && (threadIdx.x & (X-1)) > k1-1) m &= 0x0; //pad all zeros since this tile does not exist }else{ - int kh = in_h % TH; - int kw = in_w % TW; - int kh1 = kh % 2 ? (kh+1)/2 : kh/2; // there could be kh1*kw1 tiles - int kw1 = kw % 2 ? (kw+1)/2 : kw/2; - if(blockIdx.y==gridDim.y-1 && threadIdx.x / X == kh1-1) m &= (!(kh%2))?(0x0FFF):(0x00FF); //pad bottom row or bottom 2 rows + int kh = in_h & (TH-1); + int kw = in_w & (TW-1); + int kh1 = kh & 1 ? (kh+1)>>1 : kh>>1; // there could be kh1*kw1 tiles + int kw1 = kw & 1 ? 
(kw+1)>>1 : kw>>1; + if(blockIdx.y==gridDim.y-1 && threadIdx.x / X == kh1-1) m &= (!(kh&1))?(0x0FFF):(0x00FF); //pad bottom row or bottom 2 rows if(blockIdx.y==gridDim.y-1 && threadIdx.x / X > kh1-1) m &= 0x0; //pad all zeros since this tile does not exist - if(blockIdx.x==gridDim.x-1 && threadIdx.x % X == kw1-1) m &= (!(kw%2))?(0x7777):(0x3333); // pad right col or right 2 cols - if(blockIdx.x==gridDim.x-1 && threadIdx.x % X > kw1-1) m &= 0x0; //pad all zeros since this tile does not exist + if(blockIdx.x==gridDim.x-1 && (threadIdx.x & (X-1)) == kw1-1) m &= (!(kw&1))?(0x7777):(0x3333); // pad right col or right 2 cols + if(blockIdx.x==gridDim.x-1 && (threadIdx.x & (X-1)) > kw1-1) m &= 0x0; //pad all zeros since this tile does not exist } - if(blockIdx.x==0 && (threadIdx.x % X) == 0) m &=0xeeee; // pad left col + if(blockIdx.x==0 && (threadIdx.x & (X-1)) == 0) m &=0xeeee; // pad left col float img_tile[16]; // Prefetch input from GMEM float filter_tile[32]; // Prefetch filter from GMEM @@ -651,7 +838,7 @@ __global__ void Winograd_kernel(const float *A, const float *B, float *C, float4 *swap_filter; float4 *swap_input; - prefetch_input_tile(A, img_tile, in_h, in_w, X, Y, m); + prefetch_input_tile(A, img_tile, in_h, in_w, X, Y, m); prefetch_filter_tile(B, filter_tile, filt_k); float4 *input_frag_buffer = (float4*) (input_frag+4); @@ -685,7 +872,7 @@ __global__ void Winograd_kernel(const float *A, const float *B, float *C, prefetch_filter_frag(filter_frag_buffer, B_frag, f_frag_offset, access_f_s[0][threadIdx.x], access_f_s[1][threadIdx.x]); } - outer_product(input_frag, filter_frag, accumulator); + outer_product(input_frag, filter_frag, accumulator); swap_input = input_frag; input_frag = input_frag_buffer; @@ -701,7 +888,7 @@ __global__ void Winograd_kernel(const float *A, const float *B, float *C, B += filt_k*BC*4*4; if(iter<(in_c-BC)){ - prefetch_input_tile(A, img_tile, in_h, in_w, X, Y, m); + prefetch_input_tile(A, img_tile, in_h, in_w, X, Y, m); prefetch_filter_tile(B, filter_tile, filt_k); } @@ -709,44 +896,50 @@ __global__ void Winograd_kernel(const float *A, const float *B, float *C, } // Transpose, transform and store accumulated result - store_output_tile(accumulator, shared_mem, C, out_h, out_w, tiles_dim_w, tiles_dim_h, X, Y, + store_output_tile(accumulator, shared_mem, C, out_h, out_w, tiles_dim_w, tiles_dim_h, X, Y, input_frag_mem, filter_frag_mem); } -/*cudaError_t convolutionForward_32Tx64x8(float *k, int in_h, int in_w, float *w, int out_h, - int out_w, int out_c, float *C, float *Ww, - int tiles_dim_w, int tiles_dim_h, int tile_size, - int in_c, int filt_k, int filt_c, int filt_h, int filt_w, int m){ - - int tile_2d_s = tile_size*tile_size; - int smem_size = (16*BN*BC + 16*BC*BK)*4; - int X = 4, Y = 8; - - - FX<<>>(w, Ww, filt_k, filt_c, filt_h, filt_w); - - // each thread block will load 32 tiles (4x4) from the single image input - // we let X*Y = 32 and arbitraraly pick X = 4 and Y = 8 - Winograd_kernel<<>>(k, Ww, C, - tiles_dim_w, tiles_dim_h, in_c, in_h, in_w, tile_size, X, Y, filt_k, filt_c, out_c, tile_2d_s, out_h, out_w); - - return cudaGetLastError(); -}*/ // } template -static void conv_winograd_stage0_f32_cuda( +static void conv_winograd_stage0_cuda( const int src0_ne0, const int src0_ne1, const int src0_ne2, const int src0_ne3, const int dst_ne0, const int dst_ne1, const int dst_ne2, const int dst_ne3, - const T * src0, float * dst, + const T * src0, T * dst, cudaStream_t stream) { - - FX<<>>(src0, dst, src0_ne3, src0_ne2, src0_ne1, src0_ne0); + // printf("doing 
FX\n"); + FX<<>>(src0, dst, src0_ne3, src0_ne2, src0_ne1, src0_ne0); } +static void conv_winograd_stage1_f16_f32_cuda(int tiles_dim_w, int tiles_dim_h, int X, int Y, + int tile_size, int tile_2d_s, + const int src0_ne0, const int src0_ne1, const int src0_ne2, const int src0_ne3, + const int src1_ne0, const int src1_ne1, const int src1_ne2, const int src1_ne3, + const int dst_ne0, const int dst_ne1, const int dst_ne2, const int dst_ne3, + const half * src0, const float * src1, float * dst, + cudaStream_t stream) { + + int64_t filt_k = src0_ne0; + int64_t in_c = src1_ne2; + int64_t in_h = src1_ne1; + int64_t in_w = src1_ne0; + int64_t filt_c = src0_ne3; + int64_t out_c = filt_k; + int64_t out_h = in_h; + int64_t out_w = in_w; + int smem_size = (16*BN*BC + 16*BC*BK)*4; + int max_size = 65536; // 64 KB + cudaFuncSetAttribute(Winograd_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, max_size); + + Winograd_kernel<<>>(src1, src0, dst, + tiles_dim_w, tiles_dim_h, in_c, in_h, in_w, tile_size, X, Y, + filt_k, filt_c, out_c, tile_2d_s, out_h, out_w); +} + static void conv_winograd_stage1_f32_f32_cuda(int tiles_dim_w, int tiles_dim_h, int X, int Y, int tile_size, int tile_2d_s, const int src0_ne0, const int src0_ne1, const int src0_ne2, const int src0_ne3, @@ -765,9 +958,9 @@ static void conv_winograd_stage1_f32_f32_cuda(int tiles_dim_w, int tiles_dim_h, int64_t out_w = in_w; int smem_size = (16*BN*BC + 16*BC*BK)*4; int max_size = 65536; // 64 KB - cudaFuncSetAttribute(Winograd_kernel<32, 4>, cudaFuncAttributeMaxDynamicSharedMemorySize, max_size); + cudaFuncSetAttribute(Winograd_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, max_size); - Winograd_kernel<32, 4><<>>(src1, src0, dst, + Winograd_kernel<<>>(src1, src0, dst, tiles_dim_w, tiles_dim_h, in_c, in_h, in_w, tile_size, X, Y, filt_k, filt_c, out_c, tile_2d_s, out_h, out_w); } @@ -776,18 +969,15 @@ static void conv_winograd_stage1_f32_f32_cuda(int tiles_dim_w, int tiles_dim_h, void ggml_cuda_op_winograd_stage0(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; // const half * src0_d = (const half *)src0->data; - // in case this tensor has already been computed in a preprocessing step, - // skip this time; if(src0 == NULL){ return; } - - float * dst_d = (float *)dst->data; + // float * dst_d = (float *)dst->data; cudaStream_t stream = ctx.stream(); // int id = ggml_cuda_get_device(); // GGML_ASSERT(src0->type == GGML_TYPE_F16); - GGML_ASSERT( dst->type == GGML_TYPE_F32); + // GGML_ASSERT( dst->type == GGML_TYPE_F32); // ggml_cuda_pool_alloc src0_ddq_as_f32(ctx.pool(id)); // if (src0->type != GGML_TYPE_F32) { @@ -803,14 +993,16 @@ void ggml_cuda_op_winograd_stage0(ggml_backend_cuda_context & ctx, ggml_tensor * // const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? 
(const float *)src0->data : src0_ddq_as_f32.get(); if(src0->type == GGML_TYPE_F32){ const float* src0_d = (const float *)src0->data; + float * dst_d = (float *)dst->data; // conv_winograd_stage0_f32_f32_cuda(src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], - conv_winograd_stage0_f32_cuda(src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], + conv_winograd_stage0_cuda(src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], src0_d, dst_d, stream); }else{ const half * src0_d = (const half *)src0->data; + half * dst_d = (half *)dst->data; // conv_winograd_stage0_f16_f32_cuda(src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], - conv_winograd_stage0_f32_cuda(src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], + conv_winograd_stage0_cuda(src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], src0_d, dst_d, stream); } @@ -820,7 +1012,7 @@ void ggml_cuda_op_winograd_stage0(ggml_backend_cuda_context & ctx, ggml_tensor * void ggml_cuda_op_winograd_stage1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; - const float * src0_d = (const float *)src0->data; + // const float * src0_d = (const float *)src0->data; const ggml_tensor * src1 = dst->src[1]; const float * src1_d = (const float *)src1->data; @@ -828,7 +1020,7 @@ void ggml_cuda_op_winograd_stage1(ggml_backend_cuda_context & ctx, ggml_tensor * float * dst_d = (float *)dst->data; cudaStream_t stream = ctx.stream(); - GGML_ASSERT(src0->type == GGML_TYPE_F32); + // GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(src1->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); @@ -849,12 +1041,26 @@ void ggml_cuda_op_winograd_stage1(ggml_backend_cuda_context & ctx, ggml_tensor * cudaMemcpyToSymbol(access_s, aux2, 64*sizeof(int)); cudaMemcpyToSymbol(tileid, tid, 64*sizeof(int)); - conv_winograd_stage1_f32_f32_cuda(tiles_dim_w, tiles_dim_h, 16, 2, - tile_size, tile_2d_s, - src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], - src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], - dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], - src0_d, src1_d, dst_d, stream); + if(src0->type == GGML_TYPE_F32){ + const float * src0_d = (const float *)src0->data; + // const float * src1_d = (const float *)src1->data; + conv_winograd_stage1_f32_f32_cuda(tiles_dim_w, tiles_dim_h, 16, 2, + tile_size, tile_2d_s, + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], + src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + src0_d, src1_d, dst_d, stream); + } else{ + const half * src0_d = (const half *)src0->data; + // const half * src1_d = (const half *)src1->data; + conv_winograd_stage1_f16_f32_cuda(tiles_dim_w, tiles_dim_h, 16, 2, + tile_size, tile_2d_s, + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], + src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + src0_d, src1_d, dst_d, stream); + } + } diff --git a/src/ggml-cuda/conv-winograd.cuh b/src/ggml-cuda/conv-winograd.cuh index 7b5c081c2..a4b85e7fd 100644 --- a/src/ggml-cuda/conv-winograd.cuh +++ b/src/ggml-cuda/conv-winograd.cuh @@ -1,11 +1,11 @@ #include "common.cuh" - +#include #define BC 8 #define BN 32 #define BK 64 -// #define TW 32 -// #define TH 4 +#define TW 32 +#define TH 4 #define BN_p 138 __constant__ int access_f_s[2][32]; diff --git a/src/ggml.c b/src/ggml.c index 7b1868bc9..ee5879fb0 100644 --- a/src/ggml.c +++ b/src/ggml.c @@ -7184,7 +7184,8 @@ struct 
ggml_tensor * ggml_winograd_stage0( is_node = true; } - struct ggml_tensor * result = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, a->ne[3], 4, 4, a->ne[2]); + // struct ggml_tensor * result = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, a->ne[3], 4, 4, a->ne[2]); + struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, a->ne[3], 4, 4, a->ne[2]); result->op = GGML_OP_WINOGRAD_STAGE0; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; diff --git a/tests/test-conv2d-winograd.cpp b/tests/test-conv2d-winograd.cpp index 2b0a72e75..e8e069549 100644 --- a/tests/test-conv2d-winograd.cpp +++ b/tests/test-conv2d-winograd.cpp @@ -37,7 +37,7 @@ struct test_model { void load_model(test_model & model, bool use_gpu = false) { // create data - int KW = 3, KH = 3, IC = 128, OC = 128; + int KW = 3, KH = 3, IC = 64, OC = 64; int IW = 64, IH = 96, N = 1; printf(" input: IC = %d, OC = %d, IW = %d, IH = %d \n ", IC, OC, IW, IH); @@ -382,15 +382,15 @@ int main(void) bool passed = true; // for(int i = 0; i < ggml_nelements(wino_res); i++) { - for(int i = 0; i < 3*28; i++) { - float diff = fabs(conv2d_data[i] - wino_data[i]); - // if(diff > 1.e-4) { - printf("(%f, %f, %f, %d) \n", - conv2d_data[i], - wino_data[i], diff, i); - // break; - // } - } + // for(int i = 0; i < 3*28; i++) { + // float diff = fabs(conv2d_data[i] - wino_data[i]); + // // if(diff > 1.e-4) { + // printf("(%f, %f, %f, %d) \n", + // conv2d_data[i], + // wino_data[i], diff, i); + // // break; + // // } + // } From 67ec2857731f8c01b3b92c18498e34c425f9f48d Mon Sep 17 00:00:00 2001 From: bssrdf Date: Sun, 29 Dec 2024 21:13:31 -0500 Subject: [PATCH 20/23] test performance for configs from sd.cpp --- tests/test-conv2d-winograd.cpp | 160 +++++++++++++++++++-------------- 1 file changed, 95 insertions(+), 65 deletions(-) diff --git a/tests/test-conv2d-winograd.cpp b/tests/test-conv2d-winograd.cpp index e8e069549..b5a7b2782 100644 --- a/tests/test-conv2d-winograd.cpp +++ b/tests/test-conv2d-winograd.cpp @@ -35,12 +35,14 @@ struct test_model { struct ggml_context * ctx; }; -void load_model(test_model & model, bool use_gpu = false) { + + +void load_model(test_model & model, int ic, int oc, int iw, int ih, bool use_gpu = false ) { // create data - int KW = 3, KH = 3, IC = 64, OC = 64; - int IW = 64, IH = 96, N = 1; + int KW = 3, KH = 3, IC = ic, OC = oc; + int IW = iw, IH = ih, N = 1; - printf(" input: IC = %d, OC = %d, IW = %d, IH = %d \n ", IC, OC, IW, IH); + // printf(" input: IC = %d, OC = %d, IW = %d, IH = %d \n ", IC, OC, IW, IH); // Initialize adata std::vector adata(KW * KH * IC * OC); @@ -65,8 +67,8 @@ void load_model(test_model & model, bool use_gpu = false) { buffer_size += 1024; // overhead } - printf("%s: ggml tensor size = %d bytes\n", __func__, (int) sizeof(ggml_tensor)); - printf("%s: backend buffer size = %0.2f MB\n", __func__, (buffer_size/ 1024.f/ 1024.f)); + // printf("%s: ggml tensor size = %d bytes\n", __func__, (int) sizeof(ggml_tensor)); + // printf("%s: backend buffer size = %0.2f MB\n", __func__, (buffer_size/ 1024.f/ 1024.f)); int num_tensors = 2; struct ggml_init_params params { @@ -78,7 +80,7 @@ void load_model(test_model & model, bool use_gpu = false) { // initialize the backend #ifdef GGML_USE_CUDA if (use_gpu) { - fprintf(stderr, "%s: using CUDA backend\n", __func__); + // fprintf(stderr, "%s: using CUDA backend\n", __func__); model.backend = ggml_backend_cuda_init(0); if (!model.backend) { fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__); @@ -220,7 +222,7 @@ struct ggml_cgraph * build_graph_1(const 
test_model& model) { return gf; } -struct ggml_cgraph * compute_graph_0(const test_model & model, ggml_gallocr_t allocr) { +double compute_graph_0(const test_model & model, ggml_gallocr_t allocr, int iters) { struct ggml_cgraph * gf = build_graph_0(model); // allocate tensors @@ -237,7 +239,7 @@ struct ggml_cgraph * compute_graph_0(const test_model & model, ggml_gallocr_t al } #endif - int iterations = 20; + ggml_backend_graph_compute(model.backend, gf); @@ -245,7 +247,7 @@ struct ggml_cgraph * compute_graph_0(const test_model & model, ggml_gallocr_t al int64_t start_time = ggml_time_us(); - for(int iter=0; iter> configs = { + std::make_tuple(64,64,48,64), + std::make_tuple(320,320,104,152), + std::make_tuple(640,640,52,76), + std::make_tuple(640,640,104,152), + std::make_tuple(960,320,104,152), + std::make_tuple(1280,1280,26,38), + std::make_tuple(1280,640,52,76), + std::make_tuple(1920,1280,26,38), + std::make_tuple(2560,1280,26,38), + std::make_tuple(512,512,104,152), + std::make_tuple(512,512,208,304), + std::make_tuple(512,256,416,608), + std::make_tuple(256,128,832,1216), + std::make_tuple(256,256,832,1216) + }; - test_model model; - load_model(model, true); + int k = 0; - ggml_gallocr_t allocr = NULL; + for (auto c : configs){ + test_model model; + load_model(model, std::get<0>(c), std::get<1>(c), std::get<2>(c), std::get<3>(c), true); - { + ggml_gallocr_t allocr = NULL; allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(model.backend)); //create the worst case graph for memory usage estimation @@ -320,65 +338,81 @@ int main(void) // compute the required memory ggml_gallocr_reserve(allocr, gf); - size_t mem_size = ggml_gallocr_get_buffer_size(allocr, 0); - fprintf(stderr, "%s: compute buffer size: %.2f MB\n", __func__, mem_size/1024.0f/1024.0f); - } - - struct ggml_cgraph * gf_res_0 = NULL; + size_t mem_size0 = ggml_gallocr_get_buffer_size(allocr, 0); + // fprintf(stderr, "%s: compute buffer size: %.2f MB\n", __func__, mem_size/1024.0f/1024.0f); + - gf_res_0 = compute_graph_0(model, allocr); + struct ggml_cgraph * gf_res_0 = NULL; + int iterations = 20; + double run_time0 = compute_graph_0(model, allocr, iterations); - // ggml_gallocr_t allocr = NULL; - { + //allocr = NULL; + allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(model.backend)); //create the worst case graph for memory usage estimation - struct ggml_cgraph * gf = build_graph_1(model); + gf = build_graph_1(model); // compute the required memory ggml_gallocr_reserve(allocr, gf); - size_t mem_size = ggml_gallocr_get_buffer_size(allocr, 0); - fprintf(stderr, "%s: compute buffer size: %.2f MB\n", __func__, mem_size/1024.0f/1024.0f); - } + size_t mem_size1 = ggml_gallocr_get_buffer_size(allocr, 0); + // fprintf(stderr, "%s: compute buffer size: %.2f MB\n", __func__, mem_size/1024.0f/1024.0f); + - struct ggml_cgraph * gf_res_1 = NULL; + struct ggml_cgraph * gf_res_1 = NULL; - gf_res_1 = compute_graph_1(model, allocr); + double run_time1 = compute_graph_1(model, allocr, iterations); - + if(k==0) { + k = 1; + fprintf(stderr, "| (IC, OC, IW, IH) | im2col TIME | im2col VRAM | wino TIME | wino VRAM \n"); + fprintf(stderr, "| --- | --- | --- | --- | --- \n"); + } + + fprintf(stderr, " | (%d, %d, %d, %d) | %.2f ms | %.2f MB | %.2f ms | %.2f MB\n", + std::get<0>(c), std::get<1>(c), std::get<2>(c), std::get<3>(c), + run_time0, mem_size0/1024.0f/1024.0f, + run_time1, mem_size1/1024.0f/1024.0f); + + ggml_free(model.ctx); + ggml_backend_buffer_free(model.buffer); + ggml_backend_free(model.backend); + 
ggml_gallocr_free(allocr); + + } - struct ggml_tensor * wino_res = NULL; - struct ggml_tensor * conv2d_res = NULL; + // struct ggml_tensor * wino_res = NULL; + // struct ggml_tensor * conv2d_res = NULL; - for(int i = 0; i < ggml_graph_n_nodes(gf_res_0); ++i) { - if(strcmp(ggml_get_name(ggml_graph_node(gf_res_0, i)), "wino_res") == 0) { - wino_res = ggml_graph_node(gf_res_0, i); - } else if(strcmp(ggml_get_name(ggml_graph_node(gf_res_0, i)), "conv2d_res") == 0) { - conv2d_res = ggml_graph_node(gf_res_0, i); - } - } + // for(int i = 0; i < ggml_graph_n_nodes(gf_res_0); ++i) { + // if(strcmp(ggml_get_name(ggml_graph_node(gf_res_0, i)), "wino_res") == 0) { + // wino_res = ggml_graph_node(gf_res_0, i); + // } else if(strcmp(ggml_get_name(ggml_graph_node(gf_res_0, i)), "conv2d_res") == 0) { + // conv2d_res = ggml_graph_node(gf_res_0, i); + // } + // } - for(int i = 0; i < ggml_graph_n_nodes(gf_res_1); ++i) { - if(strcmp(ggml_get_name(ggml_graph_node(gf_res_1, i)), "wino_res") == 0) { - wino_res = ggml_graph_node(gf_res_1, i); - } else if(strcmp(ggml_get_name(ggml_graph_node(gf_res_1, i)), "conv2d_res") == 0) { - conv2d_res = ggml_graph_node(gf_res_1, i); - } - } + // for(int i = 0; i < ggml_graph_n_nodes(gf_res_1); ++i) { + // if(strcmp(ggml_get_name(ggml_graph_node(gf_res_1, i)), "wino_res") == 0) { + // wino_res = ggml_graph_node(gf_res_1, i); + // } else if(strcmp(ggml_get_name(ggml_graph_node(gf_res_1, i)), "conv2d_res") == 0) { + // conv2d_res = ggml_graph_node(gf_res_1, i); + // } + // } - std::vector wino_data(ggml_nelements(wino_res)); - std::vector conv2d_data(ggml_nelements(conv2d_res)); + // std::vector wino_data(ggml_nelements(wino_res)); + // std::vector conv2d_data(ggml_nelements(conv2d_res)); - ggml_backend_tensor_get(wino_res, wino_data.data(), 0, ggml_nbytes(wino_res)); - ggml_backend_tensor_get(conv2d_res, conv2d_data.data(), 0, ggml_nbytes(conv2d_res)); + // ggml_backend_tensor_get(wino_res, wino_data.data(), 0, ggml_nbytes(wino_res)); + // ggml_backend_tensor_get(conv2d_res, conv2d_data.data(), 0, ggml_nbytes(conv2d_res)); - printf("\nPerforming test:\n"); + // printf("\nPerforming test:\n"); bool passed = true; // for(int i = 0; i < ggml_nelements(wino_res); i++) { @@ -394,10 +428,6 @@ int main(void) - ggml_free(model.ctx); - - ggml_backend_buffer_free(model.buffer); - ggml_backend_free(model.backend); - ggml_gallocr_free(allocr); + return 0; } From d7103e316d1608684386943b7f4a40f870f0499e Mon Sep 17 00:00:00 2001 From: bssrdf Date: Wed, 1 Jan 2025 11:11:00 -0500 Subject: [PATCH 21/23] refactor test-conv2d-winograd --- tests/test-conv2d-winograd.cpp | 123 ++++++++++----------------------- 1 file changed, 38 insertions(+), 85 deletions(-) diff --git a/tests/test-conv2d-winograd.cpp b/tests/test-conv2d-winograd.cpp index b5a7b2782..3be4f369d 100644 --- a/tests/test-conv2d-winograd.cpp +++ b/tests/test-conv2d-winograd.cpp @@ -140,6 +140,8 @@ void load_model(test_model & model, int ic, int oc, int iw, int ih, bool use_gpu } } +typedef struct ggml_cgraph* (*build_graph_t)(const test_model& model); + struct ggml_cgraph * build_graph_0(const test_model& model) { static size_t buf_size = ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead(); static std::vector buf(buf_size); @@ -222,8 +224,13 @@ struct ggml_cgraph * build_graph_1(const test_model& model) { return gf; } -double compute_graph_0(const test_model & model, ggml_gallocr_t allocr, int iters) { - struct ggml_cgraph * gf = build_graph_0(model); + + + +std::vector compute_graph(const test_model & model, 
ggml_gallocr_t allocr, + build_graph_t build_graph, int iters, double *t) { + struct ggml_cgraph * gf = build_graph(model); + // allocate tensors ggml_gallocr_alloc_graph(allocr, gf); @@ -260,50 +267,25 @@ double compute_graph_0(const test_model & model, ggml_gallocr_t allocr, int iter //ggml_graph_print(gf); - return time_us/1000; -} - - -double compute_graph_1(const test_model & model, ggml_gallocr_t allocr, int iters) { - struct ggml_cgraph * gf = build_graph_1(model); - - // allocate tensors - ggml_gallocr_alloc_graph(allocr, gf); - int n_threads = 1; - - if (ggml_backend_is_cpu(model.backend)) { - ggml_backend_cpu_set_n_threads(model.backend, n_threads); - } - -#ifdef GGML_USE_METAL - if (ggml_backend_is_metal(model.backend)) { - ggml_backend_metal_set_n_cb(model.backend, n_threads); - } -#endif - - - ggml_backend_graph_compute(model.backend, gf); - - ggml_backend_synchronize(model.backend); + struct ggml_tensor *res = NULL; - int64_t start_time = ggml_time_us(); - - for(int iter=0; iter data(ggml_nelements(res)); + ggml_backend_tensor_get(res, data.data(), 0, ggml_nbytes(res)); - //ggml_graph_print(gf); + *t = time_us/1000; + return data; - return time_us/1000; } + int main(void) { ggml_time_init(); @@ -345,7 +327,8 @@ int main(void) struct ggml_cgraph * gf_res_0 = NULL; int iterations = 20; - double run_time0 = compute_graph_0(model, allocr, iterations); + double run_time0; + std::vector conv2d_data = compute_graph(model, allocr, build_graph_0, iterations, &run_time0); //allocr = NULL; @@ -363,11 +346,12 @@ int main(void) struct ggml_cgraph * gf_res_1 = NULL; - double run_time1 = compute_graph_1(model, allocr, iterations); + double run_time1; + std::vector wino_data = compute_graph(model, allocr, build_graph_1, iterations, &run_time1); if(k==0) { k = 1; - fprintf(stderr, "| (IC, OC, IW, IH) | im2col TIME | im2col VRAM | wino TIME | wino VRAM \n"); + fprintf(stderr, "| (IC, OC, IW, IH) | im2col+GEMM TIME | im2col+GEMM VRAM | Winograd TIME | Winograd VRAM \n"); fprintf(stderr, "| --- | --- | --- | --- | --- \n"); } @@ -376,6 +360,18 @@ int main(void) run_time0, mem_size0/1024.0f/1024.0f, run_time1, mem_size1/1024.0f/1024.0f); + + // for(int i = 0; i < ggml_nelements(wino_res); i++) { + // for(int i = 0; i < 3*28; i++) { + // float diff = fabs(conv2d_data[i] - wino_data[i]); + // // if(diff > 1.e-4) { + // printf("(%f, %f, %f, %d) \n", + // conv2d_data[i], + // wino_data[i], diff, i); + // // break; + // // } + // } + ggml_free(model.ctx); ggml_backend_buffer_free(model.buffer); ggml_backend_free(model.backend); @@ -383,51 +379,8 @@ int main(void) } - - - - // struct ggml_tensor * wino_res = NULL; - // struct ggml_tensor * conv2d_res = NULL; - - // for(int i = 0; i < ggml_graph_n_nodes(gf_res_0); ++i) { - // if(strcmp(ggml_get_name(ggml_graph_node(gf_res_0, i)), "wino_res") == 0) { - // wino_res = ggml_graph_node(gf_res_0, i); - // } else if(strcmp(ggml_get_name(ggml_graph_node(gf_res_0, i)), "conv2d_res") == 0) { - // conv2d_res = ggml_graph_node(gf_res_0, i); - // } - // } - - // for(int i = 0; i < ggml_graph_n_nodes(gf_res_1); ++i) { - // if(strcmp(ggml_get_name(ggml_graph_node(gf_res_1, i)), "wino_res") == 0) { - // wino_res = ggml_graph_node(gf_res_1, i); - // } else if(strcmp(ggml_get_name(ggml_graph_node(gf_res_1, i)), "conv2d_res") == 0) { - // conv2d_res = ggml_graph_node(gf_res_1, i); - // } - // } - - // std::vector wino_data(ggml_nelements(wino_res)); - // std::vector conv2d_data(ggml_nelements(conv2d_res)); - - // ggml_backend_tensor_get(wino_res, wino_data.data(), 0, 
ggml_nbytes(wino_res)); - // ggml_backend_tensor_get(conv2d_res, conv2d_data.data(), 0, ggml_nbytes(conv2d_res)); - // printf("\nPerforming test:\n"); - - bool passed = true; - // for(int i = 0; i < ggml_nelements(wino_res); i++) { - // for(int i = 0; i < 3*28; i++) { - // float diff = fabs(conv2d_data[i] - wino_data[i]); - // // if(diff > 1.e-4) { - // printf("(%f, %f, %f, %d) \n", - // conv2d_data[i], - // wino_data[i], diff, i); - // // break; - // // } - // } - - - return 0; } From 3dbd324b882189365b6399129f2d641768dc5a36 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Sun, 5 Jan 2025 15:38:34 -0500 Subject: [PATCH 22/23] add the missing ggml_gallocr_free(allocr) --- tests/test-conv2d-winograd.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/test-conv2d-winograd.cpp b/tests/test-conv2d-winograd.cpp index 3be4f369d..b113a7f51 100644 --- a/tests/test-conv2d-winograd.cpp +++ b/tests/test-conv2d-winograd.cpp @@ -303,7 +303,8 @@ int main(void) std::make_tuple(512,512,208,304), std::make_tuple(512,256,416,608), std::make_tuple(256,128,832,1216), - std::make_tuple(256,256,832,1216) + std::make_tuple(256,256,832,1216), + std::make_tuple(320,256,1024,1920) }; int k = 0; @@ -330,8 +331,9 @@ int main(void) double run_time0; std::vector conv2d_data = compute_graph(model, allocr, build_graph_0, iterations, &run_time0); + ggml_gallocr_free(allocr); - //allocr = NULL; + allocr = NULL; allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(model.backend)); From bf845d58172857a07096ed2fca251b2afa9da401 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Mon, 6 Jan 2025 15:01:45 -0500 Subject: [PATCH 23/23] get rid of #defines by using template arguments --- src/ggml-cuda/conv-winograd.cu | 58 ++++++++++----------------------- src/ggml-cuda/conv-winograd.cuh | 4 +-- 2 files changed, 20 insertions(+), 42 deletions(-) diff --git a/src/ggml-cuda/conv-winograd.cu b/src/ggml-cuda/conv-winograd.cu index 0a62d8dc0..6b4c53a1a 100644 --- a/src/ggml-cuda/conv-winograd.cu +++ b/src/ggml-cuda/conv-winograd.cu @@ -377,6 +377,7 @@ __device__ __forceinline__ void transform_output_tile(float * __restrict__ pOut } } +template __device__ __forceinline__ unsigned int get_mask(int idd, int tiles_dim_w, int tiles_dim_h, int tw, int th, int out_w, int out_h){ @@ -411,6 +412,7 @@ __device__ __forceinline__ unsigned int get_mask(int idd, int tiles_dim_w, int t return mask; } +template __device__ __forceinline__ void store_output_tile(float4 acumm_smem[][16], float *shared_mem, float * __restrict__ C, int out_h, int out_w, int tiles_dim_w, int tiles_dim_h, int tw, int th, float4 *input_frag_mem, float4* filter_frag_mem){ @@ -440,8 +442,8 @@ float4 *input_frag_mem, float4* filter_frag_mem){ int id2 = (idd2 & (tw-1)) * 2 + (idd2 / tw) * out_w * 2; // unsigned short mask1 = 0x000F; - unsigned int mask1 = get_mask(idd1, tiles_dim_w, tiles_dim_h, tw, th, out_w, out_h); - unsigned int mask2 = get_mask(idd2, tiles_dim_w, tiles_dim_h, tw, th, out_w, out_h); + unsigned int mask1 = get_mask(idd1, tiles_dim_w, tiles_dim_h, tw, th, out_w, out_h); + unsigned int mask2 = get_mask(idd2, tiles_dim_w, tiles_dim_h, tw, th, out_w, out_h); // output transpose step int t=0; @@ -710,6 +712,7 @@ __device__ __forceinline__ void prefetch_filter_tile(const T * __restrict__ pInp } } +template __device__ __forceinline__ void prefetch_input_tile(const float * __restrict__ pInputs, float *tile, int in_h, int in_w, int tw, int th, unsigned short mask){ @@ -775,7 +778,7 @@ __device__ __forceinline__ void 
prefetch_input_frag(float4* input_frag, float4 *((float4*) (input_frag + 3)) = *(A_frag + frag_offset + offset2); //3=2+1 } -template +template __global__ void Winograd_kernel(const float *A, const T *B, float *C, int tiles_dim_w, int tiles_dim_h, int in_c, int in_h, int in_w, @@ -838,7 +841,7 @@ __global__ void Winograd_kernel(const float *A, const T *B, float *C, float4 *swap_filter; float4 *swap_input; - prefetch_input_tile(A, img_tile, in_h, in_w, X, Y, m); + prefetch_input_tile(A, img_tile, in_h, in_w, X, Y, m); prefetch_filter_tile(B, filter_tile, filt_k); float4 *input_frag_buffer = (float4*) (input_frag+4); @@ -888,7 +891,7 @@ __global__ void Winograd_kernel(const float *A, const T *B, float *C, B += filt_k*BC*4*4; if(iter<(in_c-BC)){ - prefetch_input_tile(A, img_tile, in_h, in_w, X, Y, m); + prefetch_input_tile(A, img_tile, in_h, in_w, X, Y, m); prefetch_filter_tile(B, filter_tile, filt_k); } @@ -896,7 +899,7 @@ __global__ void Winograd_kernel(const float *A, const T *B, float *C, } // Transpose, transform and store accumulated result - store_output_tile(accumulator, shared_mem, C, out_h, out_w, tiles_dim_w, tiles_dim_h, X, Y, + store_output_tile(accumulator, shared_mem, C, out_h, out_w, tiles_dim_w, tiles_dim_h, X, Y, input_frag_mem, filter_frag_mem); } @@ -915,37 +918,13 @@ static void conv_winograd_stage0_cuda( } -static void conv_winograd_stage1_f16_f32_cuda(int tiles_dim_w, int tiles_dim_h, int X, int Y, - int tile_size, int tile_2d_s, - const int src0_ne0, const int src0_ne1, const int src0_ne2, const int src0_ne3, - const int src1_ne0, const int src1_ne1, const int src1_ne2, const int src1_ne3, - const int dst_ne0, const int dst_ne1, const int dst_ne2, const int dst_ne3, - const half * src0, const float * src1, float * dst, - cudaStream_t stream) { - - int64_t filt_k = src0_ne0; - int64_t in_c = src1_ne2; - int64_t in_h = src1_ne1; - int64_t in_w = src1_ne0; - int64_t filt_c = src0_ne3; - int64_t out_c = filt_k; - int64_t out_h = in_h; - int64_t out_w = in_w; - int smem_size = (16*BN*BC + 16*BC*BK)*4; - int max_size = 65536; // 64 KB - cudaFuncSetAttribute(Winograd_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, max_size); - - Winograd_kernel<<>>(src1, src0, dst, - tiles_dim_w, tiles_dim_h, in_c, in_h, in_w, tile_size, X, Y, - filt_k, filt_c, out_c, tile_2d_s, out_h, out_w); -} - -static void conv_winograd_stage1_f32_f32_cuda(int tiles_dim_w, int tiles_dim_h, int X, int Y, +template +static void conv_winograd_stage1_cuda(int tiles_dim_w, int tiles_dim_h, int X, int Y, int tile_size, int tile_2d_s, const int src0_ne0, const int src0_ne1, const int src0_ne2, const int src0_ne3, const int src1_ne0, const int src1_ne1, const int src1_ne2, const int src1_ne3, const int dst_ne0, const int dst_ne1, const int dst_ne2, const int dst_ne3, - const float * src0, const float * src1, float * dst, + const T * src0, const float * src1, float * dst, cudaStream_t stream) { int64_t filt_k = src0_ne0; @@ -958,14 +937,13 @@ static void conv_winograd_stage1_f32_f32_cuda(int tiles_dim_w, int tiles_dim_h, int64_t out_w = in_w; int smem_size = (16*BN*BC + 16*BC*BK)*4; int max_size = 65536; // 64 KB - cudaFuncSetAttribute(Winograd_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, max_size); + cudaFuncSetAttribute(Winograd_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, max_size); - Winograd_kernel<<>>(src1, src0, dst, + Winograd_kernel<<>>(src1, src0, dst, tiles_dim_w, tiles_dim_h, in_c, in_h, in_w, tile_size, X, Y, filt_k, filt_c, out_c, tile_2d_s, out_h, out_w); } - void 
ggml_cuda_op_winograd_stage0(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; // const half * src0_d = (const half *)src0->data; @@ -1043,8 +1021,8 @@ void ggml_cuda_op_winograd_stage1(ggml_backend_cuda_context & ctx, ggml_tensor * if(src0->type == GGML_TYPE_F32){ const float * src0_d = (const float *)src0->data; - // const float * src1_d = (const float *)src1->data; - conv_winograd_stage1_f32_f32_cuda(tiles_dim_w, tiles_dim_h, 16, 2, + // const float * src1_d = (const float *)src1->data; + conv_winograd_stage1_cuda(tiles_dim_w, tiles_dim_h, 16, 2, tile_size, tile_2d_s, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], @@ -1052,8 +1030,8 @@ void ggml_cuda_op_winograd_stage1(ggml_backend_cuda_context & ctx, ggml_tensor * src0_d, src1_d, dst_d, stream); } else{ const half * src0_d = (const half *)src0->data; - // const half * src1_d = (const half *)src1->data; - conv_winograd_stage1_f16_f32_cuda(tiles_dim_w, tiles_dim_h, 16, 2, + // const half * src1_d = (const half *)src1->data; + conv_winograd_stage1_cuda(tiles_dim_w, tiles_dim_h, 16, 2, tile_size, tile_2d_s, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], diff --git a/src/ggml-cuda/conv-winograd.cuh b/src/ggml-cuda/conv-winograd.cuh index a4b85e7fd..626c7fa7b 100644 --- a/src/ggml-cuda/conv-winograd.cuh +++ b/src/ggml-cuda/conv-winograd.cuh @@ -4,8 +4,8 @@ #define BC 8 #define BN 32 #define BK 64 -#define TW 32 -#define TH 4 +// #define TW 32 +// #define TH 4 #define BN_p 138 __constant__ int access_f_s[2][32];
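
For reference, the FX kernel above (through the f_row*/f_col* helpers) computes the Winograd F(2x2, 3x3) filter transform W = G g Gt for each (k, c) filter slice, with

    G = [ 1    0    0  ]
        [ 1/2  1/2  1/2 ]
        [ 1/2 -1/2  1/2 ]
        [ 0    0    1  ]

Below is a minimal host-side sketch of that 3x3 -> 4x4 transform; it is not part of the patch, and the function name and row-major layout are assumptions made only to mirror the row/column helper functions shown in the kernel:

    // Hypothetical CPU reference for the stage-0 filter transform W = G * g * G^T.
    // g: one 3x3 kernel slice, row-major; w: resulting 4x4 tile, row-major.
    static void winograd_filter_transform_3x3_ref(const float g[9], float w[16]) {
        float t[12];                      // t = G * g, 4 rows x 3 cols
        for (int j = 0; j < 3; ++j) {
            t[0*3 + j] = g[0*3 + j];                                     // row 1: g0
            t[1*3 + j] = 0.5f * (g[0*3 + j] + g[1*3 + j] + g[2*3 + j]);  // row 2: (g0+g1+g2)/2
            t[2*3 + j] = 0.5f * (g[0*3 + j] - g[1*3 + j] + g[2*3 + j]);  // row 3: (g0-g1+g2)/2
            t[3*3 + j] = g[2*3 + j];                                     // row 4: g2
        }
        for (int i = 0; i < 4; ++i) {     // w = t * G^T, same weights applied per column
            w[i*4 + 0] = t[i*3 + 0];
            w[i*4 + 1] = 0.5f * (t[i*3 + 0] + t[i*3 + 1] + t[i*3 + 2]);
            w[i*4 + 2] = 0.5f * (t[i*3 + 0] - t[i*3 + 1] + t[i*3 + 2]);
            w[i*4 + 3] = t[i*3 + 2];
        }
    }

Because this transform depends only on the weights, it is split into its own stage-0 kernel, while the second-stage kernel combines the input-tile transform, the per-tile multiply-accumulate over channels, and the inverse output transform in a single launch.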